Merge branch 'develop' into matthew/preview_urls

This commit is contained in:
Matthew Hodgson 2016-03-27 22:54:42 +01:00
commit d9d48aad2d
217 changed files with 8815 additions and 7174 deletions

View File

@ -51,3 +51,6 @@ Steven Hammerton <steven.hammerton at openmarket.com>
Mads Robin Christensen <mads at v42 dot dk> Mads Robin Christensen <mads at v42 dot dk>
* CentOS 7 installation instructions. * CentOS 7 installation instructions.
Florent Violleau <floviolleau at gmail dot com>
* Add Raspberry Pi installation instructions and general troubleshooting items

View File

@ -1,3 +1,90 @@
Changes in synapse v0.13.3 (2016-02-11)
=======================================
* Fix bug where ``/sync`` would occasionally return events in the wrong room.
Changes in synapse v0.13.2 (2016-02-11)
=======================================
* Fix bug where ``/events`` would fail to skip some events if there had been
more events than the limit specified since the last request (PR #570)
Changes in synapse v0.13.1 (2016-02-10)
=======================================
* Bump matrix-angular-sdk (matrix web console) dependency to 0.6.8 to
pull in the fix for SYWEB-361 so that the default client can display
HTML messages again(!)
Changes in synapse v0.13.0 (2016-02-10)
=======================================
This version includes an upgrade of the schema, specifically adding an index to
the ``events`` table. This may cause synapse to pause for several minutes the
first time it is started after the upgrade.
Changes:
* Improve general performance (PR #540, #543. #544, #54, #549, #567)
* Change guest user ids to be incrementing integers (PR #550)
* Improve performance of public room list API (PR #552)
* Change profile API to omit keys rather than return null (PR #557)
* Add ``/media/r0`` endpoint prefix, which is equivalent to ``/media/v1/``
(PR #595)
Bug fixes:
* Fix bug with upgrading guest accounts where it would fail if you opened the
registration email on a different device (PR #547)
* Fix bug where unread count could be wrong (PR #568)
Changes in synapse v0.12.1-rc1 (2016-01-29)
===========================================
Features:
* Add unread notification counts in ``/sync`` (PR #456)
* Add support for inviting 3pids in ``/createRoom`` (PR #460)
* Add ability for guest accounts to upgrade (PR #462)
* Add ``/versions`` API (PR #468)
* Add ``event`` to ``/context`` API (PR #492)
* Add specific error code for invalid user names in ``/register`` (PR #499)
* Add support for push badge counts (PR #507)
* Add support for non-guest users to peek in rooms using ``/events`` (PR #510)
Changes:
* Change ``/sync`` so that guest users only get rooms they've joined (PR #469)
* Change to require unbanning before other membership changes (PR #501)
* Change default push rules to notify for all messages (PR #486)
* Change default push rules to not notify on membership changes (PR #514)
* Change default push rules in one to one rooms to only notify for events that
are messages (PR #529)
* Change ``/sync`` to reject requests with a ``from`` query param (PR #512)
* Change server manhole to use SSH rather than telnet (PR #473)
* Change server to require AS users to be registered before use (PR #487)
* Change server not to start when ASes are invalidly configured (PR #494)
* Change server to require ID and ``as_token`` to be unique for AS's (PR #496)
* Change maximum pagination limit to 1000 (PR #497)
Bug fixes:
* Fix bug where ``/sync`` didn't return when something under the leave key
changed (PR #461)
* Fix bug where we returned smaller rather than larger than requested
thumbnails when ``method=crop`` (PR #464)
* Fix thumbnails API to only return cropped thumbnails when asking for a
cropped thumbnail (PR #475)
* Fix bug where we occasionally still logged access tokens (PR #477)
* Fix bug where ``/events`` would always return immediately for guest users
(PR #480)
* Fix bug where ``/sync`` unexpectedly returned old left rooms (PR #481)
* Fix enabling and disabling push rules (PR #498)
* Fix bug where ``/register`` returned 500 when given unicode username
(PR #513)
Changes in synapse v0.12.0 (2016-01-04) Changes in synapse v0.12.0 (2016-01-04)
======================================= =======================================

View File

@ -21,5 +21,6 @@ recursive-include synapse/static *.html
recursive-include synapse/static *.js recursive-include synapse/static *.js
exclude jenkins.sh exclude jenkins.sh
exclude jenkins*.sh
prune demo/etc prune demo/etc

View File

@ -125,6 +125,15 @@ Installing prerequisites on Mac OS X::
sudo easy_install pip sudo easy_install pip
sudo pip install virtualenv sudo pip install virtualenv
Installing prerequisites on Raspbian::
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
sudo pip install --upgrade pip
sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
@ -310,6 +319,18 @@ may need to manually upgrade it::
sudo pip install --upgrade pip sudo pip install --upgrade pip
Installing may fail with ``Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)``.
You can fix this by manually upgrading pip and virtualenv::
sudo pip install --upgrade virtualenv
You can next rerun ``virtualenv -p python2.7 synapse`` to update the virtual env.
Installing may fail during installing virtualenv with ``InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.``
You can fix this by manually installing ndg-httpsclient::
pip install --upgrade ndg-httpsclient
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``. Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
You can fix this by upgrading setuptools:: You can fix this by upgrading setuptools::
@ -504,7 +525,6 @@ Logging In To An Existing Account
Just enter the ``@localpart:my.domain.here`` Matrix user ID and password into Just enter the ``@localpart:my.domain.here`` Matrix user ID and password into
the form and click the Login button. the form and click the Login button.
Identity Servers Identity Servers
================ ================
@ -524,6 +544,26 @@ as the primary means of identity and E2E encryption is not complete. As such,
we are running a single identity server (https://matrix.org) at the current we are running a single identity server (https://matrix.org) at the current
time. time.
Password reset
==============
If a user has registered an email address to their account using an identity
server, they can request a password-reset token via clients such as Vector.
A manual password reset can be done via direct database access as follows.
First calculate the hash of the new password:
$ source ~/.synapse/bin/activate
$ ./scripts/hash_password
Password:
Confirm password:
$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
Then update the `users` table in the database:
UPDATE users SET password_hash='$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
WHERE name='@test:test.com';
Where's the spec?! Where's the spec?!
================== ==================
@ -545,3 +585,20 @@ Building internal API documentation::
python setup.py build_sphinx python setup.py build_sphinx
Halp!! Synapse eats all my RAM!
===============================
Synapse's architecture is quite RAM hungry currently - we deliberately
cache a lot of recent room data and metadata in RAM in order to speed up
common requests. We'll improve this in future, but for now the easiest
way to either reduce the RAM usage (at the risk of slowing things down)
is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
variable. Roughly speaking, a SYNAPSE_CACHE_FACTOR of 1.0 will max out
at around 3-4GB of resident memory - this is what we currently run the
matrix.org on. The default setting is currently 0.1, which is probably
around a ~700MB footprint. You can dial it down further to 0.02 if
desired, which targets roughly ~512MB. Conversely you can dial it up if
you need performance for lots of users and have a box with a lot of RAM.

22
jenkins-flake8.sh Executable file
View File

@ -0,0 +1,22 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
export PEP8SUFFIX="--output-file=violations.flake8.log"
rm .coverage* || echo "No coverage files to remove"
tox -e packaging -e pep8

61
jenkins-postgres.sh Executable file
View File

@ -0,0 +1,61 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
$TOX_BIN/pip install psycopg2
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
./jenkins/prep_sytest_for_postgres.sh
echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-base $PORT_BASE
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

55
jenkins-sqlite.sh Executable file
View File

@ -0,0 +1,55 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8500}
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-base $PORT_BASE
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

25
jenkins-unittests.sh Executable file
View File

@ -0,0 +1,25 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox -e py27

View File

@ -1,6 +1,11 @@
#!/bin/bash -eu #!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml # Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit" export TRIAL_FLAGS="--reporter=subunit"
@ -26,7 +31,7 @@ TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else else
(cd .sytest-base; git fetch) (cd .sytest-base; git fetch -p)
fi fi
rm -rf sytest rm -rf sytest
@ -52,7 +57,7 @@ RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES=$RUN_POSTGRES:$port RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF cat > localhost-$port/database.yaml << EOF
name: psycopg2 name: psycopg2
args: args:
@ -62,7 +67,7 @@ EOF
done done
# Run if both postgresql databases exist # Run if both postgresql databases exist
if test $RUN_POSTGRES = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL"; echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2 $TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \ ./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \

View File

@ -86,9 +86,12 @@ def used_names(prefix, item, defs, names):
for name, funcs in defs.get('class', {}).items(): for name, funcs in defs.get('class', {}).items():
used_names(prefix + name + ".", name, funcs, names) used_names(prefix + name + ".", name, funcs, names)
path = prefix.rstrip('.')
for used in defs.get('uses', ()): for used in defs.get('uses', ()):
if used in names: if used in names:
names[used].setdefault('used', {}).setdefault(item, []).append(prefix.rstrip('.')) if item:
names[item].setdefault('uses', []).append(used)
names[used].setdefault('used', {}).setdefault(item, []).append(path)
if __name__ == '__main__': if __name__ == '__main__':
@ -113,6 +116,10 @@ if __name__ == '__main__':
"--referrers", default=0, type=int, "--referrers", default=0, type=int,
help="Include referrers up to the given depth" help="Include referrers up to the given depth"
) )
parser.add_argument(
"--referred", default=0, type=int,
help="Include referred down to the given depth"
)
parser.add_argument( parser.add_argument(
"--format", default="yaml", "--format", default="yaml",
help="Output format, one of 'yaml' or 'dot'" help="Output format, one of 'yaml' or 'dot'"
@ -161,6 +168,20 @@ if __name__ == '__main__':
continue continue
result[name] = definition result[name] = definition
referred_depth = args.referred
referred = set()
while referred_depth:
referred_depth -= 1
for entry in result.values():
for uses in entry.get("uses", ()):
referred.add(uses)
for name, definition in names.items():
if not name in referred:
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
result[name] = definition
if args.format == 'yaml': if args.format == 'yaml':
yaml.dump(result, sys.stdout, default_flow_style=False) yaml.dump(result, sys.stdout, default_flow_style=False)
elif args.format == 'dot': elif args.format == 'dot':

24
scripts-dev/dump_macaroon.py Executable file
View File

@ -0,0 +1,24 @@
#!/usr/bin/env python2
import pymacaroons
import sys
if len(sys.argv) == 1:
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
sys.exit(1)
macaroon_string = sys.argv[1]
key = sys.argv[2] if len(sys.argv) > 2 else None
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
print macaroon.inspect()
print ""
verifier = pymacaroons.Verifier()
verifier.satisfy_general(lambda c: True)
try:
verifier.verify(macaroon, key)
print "Signature is correct"
except Exception as e:
print e.message

View File

@ -0,0 +1,62 @@
#! /usr/bin/python
import ast
import argparse
import os
import sys
import yaml
PATTERNS_V1 = []
PATTERNS_V2 = []
RESULT = {
"v1": PATTERNS_V1,
"v2": PATTERNS_V2,
}
class CallVisitor(ast.NodeVisitor):
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
name = node.func.id
else:
return
if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
PATTERNS_V2.append(node.args[0].s)
def find_patterns_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = CallVisitor()
visitor.visit(input_ast)
def find_patterns_in_file(filepath):
with open(filepath) as f:
find_patterns_in_code(f.read())
parser = argparse.ArgumentParser(description='Find url patterns.')
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
find_patterns_in_file(filepath)
PATTERNS_V1.sort()
PATTERNS_V2.sort()
yaml.dump(RESULT, sys.stdout, default_flow_style=False)

View File

@ -0,0 +1,67 @@
import requests
import collections
import sys
import time
import json
Entry = collections.namedtuple("Entry", "name position rows")
ROW_TYPES = {}
def row_type_for_columns(name, column_names):
column_names = tuple(column_names)
row_type = ROW_TYPES.get((name, column_names))
if row_type is None:
row_type = collections.namedtuple(name, column_names)
ROW_TYPES[(name, column_names)] = row_type
return row_type
def parse_response(content):
streams = json.loads(content)
result = {}
for name, value in streams.items():
row_type = row_type_for_columns(name, value["field_names"])
position = value["position"]
rows = [row_type(*row) for row in value["rows"]]
result[name] = Entry(name, position, rows)
return result
def replicate(server, streams):
return parse_response(requests.get(
server + "/_synapse/replication",
verify=False,
params=streams
).content)
def main():
server = sys.argv[1]
streams = None
while not streams:
try:
streams = {
row.name: row.position
for row in replicate(server, {"streams":"-1"})["streams"].rows
}
except requests.exceptions.ConnectionError as e:
time.sleep(0.1)
while True:
try:
results = replicate(server, streams)
except:
sys.stdout.write("connection_lost("+ repr(streams) + ")\n")
break
for update in results.values():
for row in update.rows:
sys.stdout.write(repr(row) + "\n")
streams[update.name] = update.position
if __name__=='__main__':
main()

View File

@ -1 +0,0 @@
perl -MCrypt::Random -MCrypt::Eksblowfish::Bcrypt -e 'print Crypt::Eksblowfish::Bcrypt::bcrypt("secret", "\$2\$12\$" . Crypt::Eksblowfish::Bcrypt::en_base64(Crypt::Random::makerandom_octet(Length=>16)))."\n"'

39
scripts/hash_password Executable file
View File

@ -0,0 +1,39 @@
#!/usr/bin/env python
import argparse
import bcrypt
import getpass
bcrypt_rounds=12
def prompt_for_pass():
password = getpass.getpass("Password: ")
if not password:
raise Exception("Password cannot be blank.")
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
raise Exception("Passwords do not match.")
return password
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Calculate the hash of a new password, so that passwords"
" can be reset")
parser.add_argument(
"-p", "--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
args = parser.parse_args()
password = args.password
if not password:
password = prompt_for_pass()
print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))

View File

@ -309,8 +309,8 @@ class Porter(object):
**self.postgres_config["args"] **self.postgres_config["args"]
) )
sqlite_engine = create_engine("sqlite3") sqlite_engine = create_engine(FakeConfig(sqlite_config))
postgres_engine = create_engine("psycopg2") postgres_engine = create_engine(FakeConfig(postgres_config))
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine)
@ -792,3 +792,8 @@ if __name__ == "__main__":
if end_error_exec_info: if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback) traceback.print_exception(exc_type, exc_value, exc_traceback)
class FakeConfig:
def __init__(self, database_config):
self.database_config = database_config

View File

@ -16,3 +16,4 @@ ignore =
[flake8] [flake8]
max-line-length = 90 max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.12.0" __version__ = "0.13.3"

View File

@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -433,16 +434,19 @@ class Auth(object):
if event.user_id != invite_event.user_id: if event.user_id != invite_event.user_id:
return False return False
try:
public_key = invite_event.content["public_key"]
if signed["mxid"] != event.state_key: if signed["mxid"] != event.state_key:
return False return False
if signed["token"] != token: if signed["token"] != token:
return False return False
for public_key_object in self.get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items(): for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items(): for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"): if not key_name.startswith("ed25519:"):
return False continue
verify_key = decode_verify_key_bytes( verify_key = decode_verify_key_bytes(
key_name, key_name,
decode_base64(public_key) decode_base64(public_key)
@ -454,10 +458,22 @@ class Auth(object):
# The caller is responsible for checking that the signing # The caller is responsible for checking that the signing
# server has not revoked that public key. # server has not revoked that public key.
return True return True
return False
except (KeyError, SignatureVerifyException,): except (KeyError, SignatureVerifyException,):
continue
return False return False
def get_public_keys(self, invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def _get_power_level_event(self, auth_events): def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
return auth_events.get(key) return auth_events.get(key)
@ -518,7 +534,7 @@ class Auth(object):
) )
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self.get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
@ -529,7 +545,8 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
self.store.insert_client_ip( preserve_context_over_fn(
self.store.insert_client_ip,
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -578,7 +595,7 @@ class Auth(object):
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_user_by_access_token(self, token): def get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -696,6 +713,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token): def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
logger.warn("Unrecognised access token - not in store: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
@ -713,6 +731,7 @@ class Auth(object):
token = request.args["access_token"][0] token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.", "Unrecognised access token.",

View File

@ -32,7 +32,6 @@ class PresenceState(object):
OFFLINE = u"offline" OFFLINE = u"offline"
UNAVAILABLE = u"unavailable" UNAVAILABLE = u"unavailable"
ONLINE = u"online" ONLINE = u"online"
FREE_FOR_CHAT = u"free_for_chat"
class JoinRules(object): class JoinRules(object):

View File

@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
import ujson as json
class Filtering(object): class Filtering(object):
@ -28,14 +30,14 @@ class Filtering(object):
return result return result
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter) self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter) return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or # TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for # replace_user_filter at some point? There's no REST API specified for
# them however # them however
def _check_valid_filter(self, user_filter_json): def check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid. """Check if the provided filter is valid.
This inspects all definitions contained within the filter. This inspects all definitions contained within the filter.
@ -129,52 +131,58 @@ class Filtering(object):
class FilterCollection(object): class FilterCollection(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self._filter_json = filter_json
room_filter_json = self.filter_json.get("room", {}) room_filter_json = self._filter_json.get("room", {})
self.room_filter = Filter({ self._room_filter = Filter({
k: v for k, v in room_filter_json.items() k: v for k, v in room_filter_json.items()
if k in ("rooms", "not_rooms") if k in ("rooms", "not_rooms")
}) })
self.room_timeline_filter = Filter(room_filter_json.get("timeline", {})) self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self.room_state_filter = Filter(room_filter_json.get("state", {})) self._room_state_filter = Filter(room_filter_json.get("state", {}))
self.room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self.room_account_data = Filter(room_filter_json.get("account_data", {})) self._room_account_data = Filter(room_filter_json.get("account_data", {}))
self.presence_filter = Filter(self.filter_json.get("presence", {})) self._presence_filter = Filter(filter_json.get("presence", {}))
self.account_data = Filter(self.filter_json.get("account_data", {})) self._account_data = Filter(filter_json.get("account_data", {}))
self.include_leave = self.filter_json.get("room", {}).get( self.include_leave = filter_json.get("room", {}).get(
"include_leave", False "include_leave", False
) )
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
def get_filter_json(self):
return self._filter_json
def timeline_limit(self): def timeline_limit(self):
return self.room_timeline_filter.limit() return self._room_timeline_filter.limit()
def presence_limit(self): def presence_limit(self):
return self.presence_filter.limit() return self._presence_filter.limit()
def ephemeral_limit(self): def ephemeral_limit(self):
return self.room_ephemeral_filter.limit() return self._room_ephemeral_filter.limit()
def filter_presence(self, events): def filter_presence(self, events):
return self.presence_filter.filter(events) return self._presence_filter.filter(events)
def filter_account_data(self, events): def filter_account_data(self, events):
return self.account_data.filter(events) return self._account_data.filter(events)
def filter_room_state(self, events): def filter_room_state(self, events):
return self.room_state_filter.filter(self.room_filter.filter(events)) return self._room_state_filter.filter(self._room_filter.filter(events))
def filter_room_timeline(self, events): def filter_room_timeline(self, events):
return self.room_timeline_filter.filter(self.room_filter.filter(events)) return self._room_timeline_filter.filter(self._room_filter.filter(events))
def filter_room_ephemeral(self, events): def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(self.room_filter.filter(events)) return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
def filter_room_account_data(self, events): def filter_room_account_data(self, events):
return self.room_account_data.filter(self.room_filter.filter(events)) return self._room_account_data.filter(self._room_filter.filter(events))
class Filter(object): class Filter(object):
@ -187,18 +195,19 @@ class Filter(object):
Returns: Returns:
bool: True if the event matches bool: True if the event matches
""" """
if isinstance(event, dict): sender = event.get("sender", None)
if not sender:
# Presence events have their 'sender' in content.user_id
content = event.get("content")
# account_data has been allowed to have non-dict content, so check type first
if isinstance(content, dict):
sender = content.get("user_id")
return self.check_fields( return self.check_fields(
event.get("room_id", None), event.get("room_id", None),
event.get("sender", None), sender,
event.get("type", None), event.get("type", None),
) )
else:
return self.check_fields(
getattr(event, "room_id", None),
getattr(event, "sender", None),
event.type,
)
def check_fields(self, room_id, sender, event_type): def check_fields(self, room_id, sender, event_type):
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.
@ -258,3 +267,6 @@ def _matches_wildcard(actual_value, filter_value):
return actual_value.startswith(type_prefix) return actual_value.startswith(type_prefix)
else: else:
return actual_value == filter_value return actual_value == filter_value
DEFAULT_FILTER_COLLECTION = FilterCollection({})

View File

@ -23,5 +23,6 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1" APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -12,3 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)

View File

@ -14,27 +14,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import synapse
from synapse.rest import ClientRestResource
import contextlib
import logging
import os
import re
import resource
import subprocess
import sys
import time
from synapse.config._base import ConfigError
sys.dont_write_bytecode = True
from synapse.python_dependencies import ( from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError check_requirements, DEPENDENCY_LINKS
) )
if __name__ == '__main__': from synapse.rest import ClientRestResource
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.storage.prepare_database import UpgradeDatabaseException
@ -50,41 +46,30 @@ from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer from twisted.internet import reactor, task, defer
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX, SERVER_KEY_V2_PREFIX,
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse import events from synapse import events
from daemonize import Daemonize from daemonize import Daemonize
import synapse
import contextlib
import logging
import os
import re
import resource
import subprocess
import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
@ -95,19 +80,8 @@ def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer): def build_resource_for_web_client(hs):
webclient_path = hs.get_config().web_client_location
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_client_resource(self):
return ClientRestResource(self)
def build_resource_for_federation(self):
return JsonResource(self)
def build_resource_for_web_client(self):
webclient_path = self.get_config().web_client_location
if not webclient_path: if not webclient_path:
try: try:
import syweb import syweb
@ -135,40 +109,8 @@ class SynapseHomeServer(HomeServer):
# return GzipFile(webclient_path) # TODO configurable? # return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
def build_resource_for_content_repo(self):
return ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
return MediaRepositoryResource(self)
def build_resource_for_server_key(self):
return LocalKey(self)
def build_resource_for_server_key_v2(self):
return KeyApiV2Resource(self)
def build_resource_for_metrics(self):
if self.get_config().enable_metrics:
return MetricsResource(self)
else:
return None
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config): def _listener_http(self, config, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_address = listener_config.get("bind_address", "")
@ -178,13 +120,11 @@ class SynapseHomeServer(HomeServer):
if tls and config.no_tls: if tls and config.no_tls:
return return
metrics_resource = self.get_resource_for_metrics()
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
for name in res["names"]: for name in res["names"]:
if name == "client": if name == "client":
client_resource = self.get_client_resource() client_resource = ClientRestResource(self)
if res["compress"]: if res["compress"]:
client_resource = gz_wrap(client_resource) client_resource = gz_wrap(client_resource)
@ -198,31 +138,40 @@ class SynapseHomeServer(HomeServer):
if name == "federation": if name == "federation":
resources.update({ resources.update({
FEDERATION_PREFIX: self.get_resource_for_federation(), FEDERATION_PREFIX: TransportLayerServer(self),
}) })
if name in ["static", "client"]: if name in ["static", "client"]:
resources.update({ resources.update({
STATIC_PREFIX: self.get_resource_for_static_content(), STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
}) })
if name in ["media", "federation", "client"]: if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self)
resources.update({ resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(), MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(), LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
),
}) })
if name in ["keys", "federation"]: if name in ["keys", "federation"]:
resources.update({ resources.update({
SERVER_KEY_PREFIX: self.get_resource_for_server_key(), SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(), SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
}) })
if name == "webclient": if name == "webclient":
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client() resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and metrics_resource: if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = metrics_resource resources[METRICS_PREFIX] = MetricsResource(self)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationResource(self)
root_resource = create_resource_tree(resources) root_resource = create_resource_tree(resources)
if tls: if tls:
@ -296,6 +245,18 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn)
return db_conn
def quit_with_error(error_string): def quit_with_error(error_string):
message_lines = error_string.split("\n") message_lines = error_string.split("\n")
@ -396,11 +357,20 @@ def setup(config_options):
Returns: Returns:
HomeServer HomeServer
""" """
try:
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
config_options, config_options,
generate_section="Homeserver" generate_section="Homeserver"
) )
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
# If a config isn't returned, and an exception isn't raised, we're just
# generating config files and shouldn't try to continue.
sys.exit(0)
config.setup_logging() config.setup_logging()
@ -416,7 +386,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"]) database_engine = create_engine(config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
@ -432,13 +402,7 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name']) logger.info("Preparing database: %s...", config.database_config['name'])
try: try:
db_conn = database_engine.module.connect( db_conn = hs.get_db_conn()
**{
k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn) database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine) hs.run_startup_checks(db_conn, database_engine)
@ -453,14 +417,18 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name']) logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
hs.start_listening() hs.start_listening()
def start():
hs.get_pusherpool().start() hs.get_pusherpool().start()
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
reactor.callWhenRunning(start)
return hs return hs
@ -675,7 +643,7 @@ def _resource_id(resource, path_seg):
the mapping should looks like _resource_id(A,C) = B. the mapping should looks like _resource_id(A,C) = B.
Args: Args:
resource (Resource): The *parent* Resource resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached. path_seg (str): The name of the child Resource to be attached.
Returns: Returns:
str: A unique string which can be a key to the child Resource. str: A unique string which can be a key to the child Resource.
@ -722,8 +690,8 @@ def run(hs):
stats["uptime_seconds"] = uptime stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users() stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False) room_count = yield hs.get_datastore().get_room_count()
stats["total_room_count"] = len(all_rooms) stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages() daily_messages = yield hs.get_datastore().count_daily_messages()
@ -745,6 +713,8 @@ def run(hs):
phone_home_task.start(60 * 60 * 24, now=False) phone_home_task.start(60 * 60 * 24, now=False)
def in_thread(): def in_thread():
# Uncomment to enable tracing of log context changes.
# sys.settrace(logcontext_tracer)
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
reactor.run() reactor.run()
@ -752,7 +722,7 @@ def run(hs):
if hs.config.daemonize: if hs.config.daemonize:
if hs.config.print_pidfile: if hs.config.print_pidfile:
print hs.config.pid_file print (hs.config.pid_file)
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",

View File

@ -29,13 +29,13 @@ NORMAL = "\x1b[m"
def start(configfile): def start(configfile):
print "Starting ...", print ("Starting ...")
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", configfile]) args.extend(["--daemonize", "-c", configfile])
try: try:
subprocess.check_call(args) subprocess.check_call(args)
print GREEN + "started" + NORMAL print (GREEN + "started" + NORMAL)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print ( print (
RED + RED +
@ -48,7 +48,7 @@ def stop(pidfile):
if os.path.exists(pidfile): if os.path.exists(pidfile):
pid = int(open(pidfile).read()) pid = int(open(pidfile).read())
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL print (GREEN + "stopped" + NORMAL)
def main(): def main():

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.config._base import ConfigError
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
@ -21,9 +22,13 @@ if __name__ == "__main__":
if action == "read": if action == "read":
key = sys.argv[2] key = sys.argv[2]
try:
config = HomeServerConfig.load_config("", sys.argv[3:]) config = HomeServerConfig.load_config("", sys.argv[3:])
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
print getattr(config, key) print (getattr(config, key))
sys.exit(0) sys.exit(0)
else: else:
sys.stderr.write("Unknown command %r\n" % (action,)) sys.stderr.write("Unknown command %r\n" % (action,))

View File

@ -17,7 +17,6 @@ import argparse
import errno import errno
import os import os
import yaml import yaml
import sys
from textwrap import dedent from textwrap import dedent
@ -105,7 +104,7 @@ class Config(object):
dir_path = cls.abspath(dir_path) dir_path = cls.abspath(dir_path)
try: try:
os.makedirs(dir_path) os.makedirs(dir_path)
except OSError, e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
@ -136,13 +135,20 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs)) results.append(getattr(cls, name)(self, *args, **kargs))
return results return results
def generate_config(self, config_dir_path, server_name, report_stats=None): def generate_config(
self,
config_dir_path,
server_name,
is_generating_file,
report_stats=None,
):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", "default_config",
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
is_generating_file=is_generating_file,
report_stats=report_stats, report_stats=report_stats,
)) ))
@ -244,8 +250,10 @@ class Config(object):
server_name = config_args.server_name server_name = config_args.server_name
if not server_name: if not server_name:
print "Must specify a server_name to a generate config for." raise ConfigError(
sys.exit(1) "Must specify a server_name to a generate config for."
" Pass -H server.name."
)
if not os.path.exists(config_dir_path): if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
@ -253,6 +261,7 @@ class Config(object):
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
report_stats=(config_args.report_stats == "yes"), report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
@ -266,7 +275,7 @@ class Config(object):
"If this server name is incorrect, you will need to" "If this server name is incorrect, you will need to"
" regenerate the SSL certificates" " regenerate the SSL certificates"
) )
sys.exit(0) return
else: else:
print ( print (
"Config file %r already exists. Generating any missing key" "Config file %r already exists. Generating any missing key"
@ -302,25 +311,25 @@ class Config(object):
specified_config.update(yaml_config) specified_config.update(yaml_config)
if "server_name" not in specified_config: if "server_name" not in specified_config:
sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n") raise ConfigError(MISSING_SERVER_NAME)
sys.exit(1)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config( _, config = obj.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name server_name=server_name,
is_generating_file=False,
) )
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config: if "report_stats" not in config:
sys.stderr.write( raise ConfigError(
"\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
MISSING_REPORT_STATS_SPIEL + "\n") MISSING_REPORT_STATS_SPIEL
sys.exit(1) )
if generate_keys: if generate_keys:
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
sys.exit(0) return
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)

40
synapse/config/api.py Normal file
View File

@ -0,0 +1,40 @@
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from synapse.api.constants import EventTypes
class ApiConfig(Config):
def read_config(self, config):
self.room_invite_state_types = config.get("room_invite_state_types", [
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
])
def default_config(cls, **kwargs):
return """\
## API Configuration ##
# A list of event types that will be included in the room_invite_state
room_invite_state_types:
- "{JoinRules}"
- "{CanonicalAlias}"
- "{RoomAvatar}"
- "{Name}"
""".format(**vars(EventTypes))

View File

@ -23,6 +23,7 @@ from .captcha import CaptchaConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .api import ApiConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
@ -32,7 +33,7 @@ from .password import PasswordConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,): PasswordConfig,):
pass pass

View File

@ -22,8 +22,14 @@ from signedjson.key import (
read_signing_keys, write_signing_keys, NACL_ED25519 read_signing_keys, write_signing_keys, NACL_ED25519
) )
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from synapse.util.stringutils import random_string_with_symbols
import os import os
import hashlib
import logging
logger = logging.getLogger(__name__)
class KeyConfig(Config): class KeyConfig(Config):
@ -40,9 +46,29 @@ class KeyConfig(Config):
config["perspectives"] config["perspectives"]
) )
def default_config(self, config_dir_path, server_name, **kwargs): self.macaroon_secret_key = config.get(
"macaroon_secret_key", self.registration_shared_secret
)
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
logger.warn("Config is missing missing macaroon_secret_key")
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
else:
macaroon_secret_key = None
return """\ return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
## Signing Keys ## ## Signing Keys ##
# Path to the signing key to sign messages with # Path to the signing key to sign messages with

View File

@ -23,22 +23,27 @@ from distutils.util import strtobool
class RegistrationConfig(Config): class RegistrationConfig(Config):
def read_config(self, config): def read_config(self, config):
self.disable_registration = not bool( self.enable_registration = bool(
strtobool(str(config["enable_registration"])) strtobool(str(config["enable_registration"]))
) )
if "disable_registration" in config: if "disable_registration" in config:
self.disable_registration = bool( self.enable_registration = not bool(
strtobool(str(config["disable_registration"])) strtobool(str(config["disable_registration"]))
) )
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.allow_guest_access = config.get("allow_guest_access", False) self.allow_guest_access = config.get("allow_guest_access", False)
self.invite_3pid_guest = (
self.allow_guest_access and config.get("invite_3pid_guest", False)
)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50)
return """\ return """\
## Registration ## ## Registration ##
@ -49,8 +54,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s"
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.
@ -60,6 +63,12 @@ class RegistrationConfig(Config):
# participate in rooms hosted on this server which have been made # participate in rooms hosted on this server which have been made
# accessible to anonymous users. # accessible to anonymous users.
allow_guest_access: False allow_guest_access: False
# The list of identity servers trusted to verify third party
# identifiers by this server.
trusted_third_party_id_servers:
- matrix.org
- vector.im
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):
@ -71,6 +80,6 @@ class RegistrationConfig(Config):
def read_arguments(self, args): def read_arguments(self, args):
if args.enable_registration is not None: if args.enable_registration is not None:
self.disable_registration = not bool( self.enable_registration = bool(
strtobool(str(args.enable_registration)) strtobool(str(args.enable_registration))
) )

View File

@ -101,4 +101,7 @@ class ContentRepositoryConfig(Config):
- width: 640 - width: 640
height: 480 height: 480
method: scale method: scale
- width: 800
height: 600
method: scale
""" % locals() """ % locals()

View File

@ -36,6 +36,7 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
factory = SynapseKeyClientFactory() factory = SynapseKeyClientFactory()
factory.path = path factory.path = path
factory.host = server_name
endpoint = matrix_federation_endpoint( endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30 reactor, server_name, ssl_context_factory, timeout=30
) )
@ -81,6 +82,8 @@ class SynapseKeyClientProtocol(HTTPClient):
self.host = self.transport.getHost() self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host) logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", self.path) self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
self.endHeaders() self.endHeaders()
self.timer = reactor.callLater( self.timer = reactor.callLater(
self.timeout, self.timeout,

View File

@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from twisted.internet import defer from twisted.internet import defer
@ -142,6 +146,8 @@ class Keyring(object):
for server_name, _ in server_and_json for server_name, _ in server_and_json
} }
with PreserveLoggingContext():
# We want to wait for any previous lookups to complete before # We want to wait for any previous lookups to complete before
# proceeding. # proceeding.
wait_on_deferred = self.wait_for_previous_lookups( wait_on_deferred = self.wait_for_previous_lookups(
@ -175,7 +181,8 @@ class Keyring(object):
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
handle_key_deferred( preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id], group_id_to_group[g_id],
deferreds[g_id], deferreds[g_id],
) )
@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads if server_name in self.key_downloads
] ]
if wait_on: if wait_on:
with PreserveLoggingContext():
yield defer.DeferredList(wait_on) yield defer.DeferredList(wait_on)
else: else:
break break
for server_name, deferred in server_to_deferred.items(): for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred) d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d self.key_downloads[server_name] = d
def rm(r, server_name): def rm(r, server_name):
@ -244,6 +252,7 @@ class Keyring(object):
for group in group_id_to_group.values(): for group in group_id_to_group.values():
for key_id in group.key_ids: for key_id in group.key_ids:
if key_id in merged_results[group.server_name]: if key_id in merged_results[group.server_name]:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback(( group_id_to_deferred[group.group_id].callback((
group.group_id, group.group_id,
group.server_name, group.server_name,
@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store_keys( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
from_server=server_name, from_server=server_name,
verify_keys=verify_keys, verify_keys=verify_keys,
@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_keys_json( preserve_fn(self.store.store_server_keys_json)(
server_name=server_name, server_name=server_name,
key_id=key_id, key_id=key_id,
from_server=server_name, from_server=server_name,
@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_verify_key( preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
) )
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()

View File

@ -168,5 +168,7 @@ class FrozenEvent(EventBase):
def __repr__(self): def __repr__(self):
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % ( return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
self.event_id, self.type, self.get("state_key", None), self.get("event_id", None),
self.get("type", None),
self.get("state_key", None),
) )

View File

@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state self.current_state = current_state
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = []

View File

@ -17,15 +17,10 @@
""" """
from .replication import ReplicationLayer from .replication import ReplicationLayer
from .transport import TransportLayer from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver): def initialize_http_replication(homeserver):
transport = TransportLayer( transport = TransportLayerClient(homeserver)
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()
)
return ReplicationLayer(homeserver, transport) return ReplicationLayer(homeserver, transport)

View File

@ -57,7 +57,7 @@ class FederationClient(FederationBase):
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
clock=self._clock, clock=self._clock,
max_len=1000, max_len=1000,
expiry_ms=120*1000, expiry_ms=120 * 1000,
reset_expiry_on_get=False, reset_expiry_on_get=False,
) )
@ -114,7 +114,7 @@ class FederationClient(FederationBase):
@log_function @log_function
def make_query(self, destination, query_type, args, def make_query(self, destination, query_type, args,
retry_on_dns_fail=True): retry_on_dns_fail=False):
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.

View File

@ -126,10 +126,8 @@ class FederationServer(FederationBase):
results = [] results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try: try:
yield d yield self._handle_new_pdu(transaction.origin, pdu)
results.append({}) results.append({})
except FederationError as e: except FederationError as e:
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)
@ -139,8 +137,8 @@ class FederationServer(FederationBase):
logger.exception("Failed to handle PDU") logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in (Edu(**x) for x in transaction.edus):
self.received_edu( yield self.received_edu(
transaction.origin, transaction.origin,
edu.edu_type, edu.edu_type,
edu.content edu.content
@ -163,11 +161,17 @@ class FederationServer(FederationBase):
) )
defer.returnValue((200, response)) defer.returnValue((200, response))
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content): def received_edu(self, origin, edu_type, content):
received_edus_counter.inc() received_edus_counter.inc()
if edu_type in self.edu_handlers: if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content) try:
yield self.edu_handlers[edu_type](origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type, e)
else: else:
logger.warn("Received EDU of type %s with no handler", edu_type) logger.warn("Received EDU of type %s with no handler", edu_type)
@ -545,8 +549,19 @@ class FederationServer(FederationBase):
return event return event
@defer.inlineCallbacks @defer.inlineCallbacks
def exchange_third_party_invite(self, invite): def exchange_third_party_invite(
ret = yield self.handler.exchange_third_party_invite(invite) self,
sender_user_id,
target_user_id,
room_id,
signed,
):
ret = yield self.handler.exchange_third_party_invite(
sender_user_id,
target_user_id,
room_id,
signed,
)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.transport_layer = transport_layer self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
self.federation_client = self self.federation_client = self

View File

@ -103,7 +103,6 @@ class TransactionQueue(object):
else: else:
return not destination.startswith("localhost") return not destination.startswith("localhost")
@defer.inlineCallbacks
def enqueue_pdu(self, pdu, destinations, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
@ -141,8 +140,6 @@ class TransactionQueue(object):
deferreds.append(deferred) deferreds.append(deferred)
yield defer.DeferredList(deferreds, consumeErrors=True)
# NO inlineCallbacks # NO inlineCallbacks
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination

View File

@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
support HTTPS), however individual pairings of servers may decide to support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol. communicate over a different (albeit still reliable) protocol.
""" """
from .server import TransportLayerServer
from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates
transactions and other requests to/from HTTP.
Attributes:
server_name (str): Local home server host
server (synapse.http.server.HttpServer): the http server to
register listeners on
client (synapse.http.client.HttpClient): the http client used to
send requests
request_handler (TransportRequestHandler): The handler to fire when we
receive requests for data.
received_handler (TransportReceivedHandler): The handler to fire when
we receive data.
"""
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
server (synapse.protocol.http.HttpServer): the http server to
register listeners on
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name
self.server = server
self.client = client
self.request_handler = None
self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View File

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class TransportLayerClient(object): class TransportLayerClient(object):
"""Sends federation HTTP requests to other servers""" """Sends federation HTTP requests to other servers"""
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_http_client()
@log_function @log_function
def get_room_state(self, destination, room_id, event_id): def get_room_state(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the """ Requests all state for a given room from the given server at the
@ -156,6 +160,7 @@ class TransportLayerClient(object):
path=path, path=path,
args=args, args=args,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
) )
defer.returnValue(content) defer.returnValue(content)

View File

@ -17,7 +17,9 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
import logging import logging
@ -28,9 +30,41 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TransportLayerServer(object): class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
super(TransportLayerServer, self).__init__(hs)
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
)
self.register_servlets()
def register_servlets(self):
register_servlets(
self.hs,
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
)
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request): def authenticate_request(self, request):
@ -98,37 +132,9 @@ class TransportLayerServer(object):
defer.returnValue((origin, content)) defer.returnValue((origin, content))
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
Args:
handler (TransportReceivedHandler)
"""
FederationSendServlet(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server)
@log_function
def register_request_handler(self, handler):
""" Register a handler that will be fired when we get asked for data.
Args:
handler (TransportRequestHandler)
"""
for servletclass in SERVLET_CLASSES:
servletclass(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
class BaseFederationServlet(object): class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter): def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler self.handler = handler
self.authenticator = authenticator self.authenticator = authenticator
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
@ -172,7 +178,9 @@ class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/" PATH = "/send/([^/]*)/"
def __init__(self, handler, server_name, **kwargs): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, **kwargs) super(FederationSendServlet, self).__init__(
handler, server_name=server_name, **kwargs
)
self.server_name = server_name self.server_name = server_name
# This is when someone is trying to send us a bunch of data. # This is when someone is trying to send us a bunch of data.
@ -412,13 +420,22 @@ class On3pidBindServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
content_bytes = request.content.read() content = parse_json_object_from_request(request)
content = json.loads(content_bytes)
if "invites" in content: if "invites" in content:
last_exception = None last_exception = None
for invite in content["invites"]: for invite in content["invites"]:
try: try:
yield self.handler.exchange_third_party_invite(invite) if "signed" not in invite or "token" not in invite["signed"]:
message = ("Rejecting received notification of third-"
"party invite without signed: %s" % (invite,))
logger.info(message)
raise SynapseError(400, message)
yield self.handler.exchange_third_party_invite(
invite["sender"],
invite["mxid"],
invite["room_id"],
invite["signed"],
)
except Exception as e: except Exception as e:
last_exception = e last_exception = e
if last_exception: if last_exception:
@ -432,6 +449,7 @@ class On3pidBindServlet(BaseFederationServlet):
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet, FederationPullServlet,
FederationEventServlet, FederationEventServlet,
FederationStateServlet, FederationStateServlet,
@ -451,3 +469,13 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
) )
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in SERVLET_CLASSES:
servletclass(
handler=hs.get_replication_layer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias, Requester
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
@ -29,6 +29,14 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VISIBILITY_PRIORITY = (
"world_readable",
"shared",
"invited",
"joined",
)
class BaseHandler(object): class BaseHandler(object):
""" """
Common base class for the event handlers. Common base class for the event handlers.
@ -53,25 +61,16 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_clients(self, user_tuples, events): def filter_events_for_clients(self, user_tuples, events, event_id_to_state):
""" Returns dict of user_id -> list of events that user is allowed to """ Returns dict of user_id -> list of events that user is allowed to
see. see.
:param (str, bool) user_tuples: (user id, is_peeking) for each
user to be checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
""" """
# If there is only one user, just get the state for that one user,
# otherwise just get all the state.
if len(user_tuples) == 1:
types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_tuples[0][0]),
)
else:
types = None
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
forgotten = yield defer.gatherResults([ forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room( self.store.who_forgot_in_room(
room_id, room_id,
@ -87,18 +86,38 @@ class BaseHandler(object):
def allowed(event, user_id, is_peeking): def allowed(event, user_id, is_peeking):
state = event_id_to_state[event.event_id] state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event: if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared") visibility = visibility_event.content.get("history_visibility", "shared")
else: else:
visibility = "shared" visibility = "shared"
if visibility not in VISIBILITY_PRIORITY:
visibility = "shared"
# if it was world_readable, it's easy: everyone can read it
if visibility == "world_readable": if visibility == "world_readable":
return True return True
if is_peeking: # Always allow history visibility events on boundaries. This is done
return False # by setting the effective visibility to the least restrictive
# of the old vs new.
if event.type == EventTypes.RoomHistoryVisibility:
prev_content = event.unsigned.get("prev_content", {})
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
prev_visibility = "shared"
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
if old_priority < new_priority:
visibility = prev_visibility
# get the user's membership at the time of the event. (or rather,
# just *after* the event. Which means that people can see their
# own join events, but not (currently) their own leave events.)
membership_event = state.get((EventTypes.Member, user_id), None) membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event: if membership_event:
if membership_event.event_id in event_id_forgotten: if membership_event.event_id in event_id_forgotten:
@ -108,20 +127,29 @@ class BaseHandler(object):
else: else:
membership = None membership = None
# if the user was a member of the room at the time of the event,
# they can see it.
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
if event.type == EventTypes.RoomHistoryVisibility: if visibility == "joined":
return not is_peeking # we weren't a member at the time of the event, so we can't
# see this event.
return False
if visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited": elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
return membership == Membership.INVITE return membership == Membership.INVITE
return True else:
# visibility is shared: user can also see the event if they have
# become a member since the event
#
# XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But
# we don't know when they left.
return not is_peeking
defer.returnValue({ defer.returnValue({
user_id: [ user_id: [
@ -134,25 +162,45 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False): def _filter_events_for_client(self, user_id, events, is_peeking=False):
# Assumes that user has at some point joined the room if not is_guest. """
res = yield self._filter_events_for_clients([(user_id, is_peeking)], events) Check which events a user is allowed to see
:param str user_id: user id to be checked
:param [synapse.events.EventBase] events: list of events to be checked
:param bool is_peeking should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
:rtype [synapse.events.EventBase]
"""
types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
res = yield self.filter_events_for_clients(
[(user_id, is_peeking)], events, event_id_to_state
)
defer.returnValue(res.get(user_id, [])) defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id): def ratelimit(self, requester):
time_now = self.clock.time() time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now, requester.user.to_string(), time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count, burst_count=self.hs.config.rc_message_burst_count,
) )
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(
retry_after_ms=int(1000*(time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder):
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id, builder.room_id,
) )
@ -161,7 +209,10 @@ class BaseHandler(object):
else: else:
depth = 1 depth = 1
prev_events = [(e, h) for e, h, _ in latest_ret] prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
@ -170,6 +221,50 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder) context = yield state_handler.compute_event_context(builder)
# If we've received an invite over federation, there are no latest
# events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite.
if (
not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member
):
prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id
)
# The prev_member_event may already be in context.current_state,
# despite us not being present in the room; in particular, if
# inviting user, and all other local users, have already left.
#
# In that case, we have all the information we need, and we don't
# want to drop "context" - not least because we may need to handle
# the invite locally, which will require us to have the whole
# context (not just prev_member_event) to auth it.
#
context_event_ids = (
e.event_id for e in context.current_state.values()
)
if (
prev_member_event and
prev_member_event.event_id not in context_event_ids
):
# The prev_member_event is missing from context, so it must
# have arrived over federation and is an outlier. We forcibly
# set our context to the invite we received over federation
builder.prev_events = (
prev_member_event.event_id,
prev_member_event.prev_events
)
context = yield state_handler.compute_event_context(
builder,
old_state=(prev_member_event,),
outlier=True
)
if builder.is_state(): if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes( builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events context.prev_state_events
@ -192,10 +287,40 @@ class BaseHandler(object):
(event, context,) (event, context,)
) )
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
for ((event_type, state_key), event) in current_state.items()
if event_type == EventTypes.Member
]
if len(room_members) == 0:
# Have we just created the room, and is this about to be the very
# first member event?
create_event = current_state.get(("m.room.create", ""))
if create_event:
return True
for (state_key, membership) in room_members:
if (
UserID.from_string(state_key).domain == self.hs.hostname
and membership == Membership.JOIN
):
return True
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_users=[]): def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[]
):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(requester)
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context.current_state.values())
@ -220,6 +345,12 @@ class BaseHandler(object):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE: if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return (
e.type == EventTypes.Member and
e.sender == event.sender
)
event.unsigned["invite_room_state"] = [ event.unsigned["invite_room_state"] = [
{ {
"type": e.type, "type": e.type,
@ -228,12 +359,8 @@ class BaseHandler(object):
"sender": e.sender, "sender": e.sender,
} }
for k, e in context.current_state.items() for k, e in context.current_state.items()
if e.type in ( if e.type in self.hs.config.room_invite_state_types
EventTypes.JoinRules, or is_inviter_member_event(e)
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
)
] ]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)
@ -269,13 +396,19 @@ class BaseHandler(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
(event_stream_id, max_stream_id) = yield self.store.persist_event( if event.type == EventTypes.Create and context.current_state:
event, context=context raise AuthError(
403,
"Changing the room create event is forbidden",
) )
action_generator = ActionGenerator(self.hs) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, self event, context, self
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
) )
destinations = set() destinations = set()
@ -293,19 +426,11 @@ class BaseHandler(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners. # Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
notify_d.addErrback(log_failure)
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)
@ -329,7 +454,8 @@ class BaseHandler(object):
if member_event.type != EventTypes.Member: if member_event.type != EventTypes.Member:
continue continue
if not self.hs.is_mine(UserID.from_string(member_event.state_key)): target_user = UserID.from_string(member_event.state_key)
if not self.hs.is_mine(target_user):
continue continue
if member_event.content["membership"] not in { if member_event.content["membership"] not in {
@ -351,18 +477,13 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more # and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having # of that decision-making and control local to the guest-having
# homeserver. # homeserver.
message_handler = self.hs.get_handlers().message_handler requester = Requester(target_user, "", True)
yield message_handler.create_and_send_event( handler = self.hs.get_handlers().room_member_handler
{ yield handler.update_membership(
"type": EventTypes.Member, requester,
"state_key": member_event.state_key, target_user,
"content": { member_event.room_id,
"membership": Membership.LEAVE, "leave",
"kind": "guest"
},
"room_id": member_event.room_id,
"sender": member_event.state_key
},
ratelimit=False, ratelimit=False,
) )
except Exception as e: except Exception as e:

View File

@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler): class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
@ -66,15 +67,18 @@ class AuthHandler(BaseHandler):
'auth' key: this method prompts for auth if none is sent. 'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client. clientip (str): The IP address of the client.
Returns: Returns:
A tuple of (authed, dict, dict) where authed is true if the client A tuple of (authed, dict, dict, session_id) where authed is true if
has successfully completed an auth flow. If it is true, the first the client has successfully completed an auth flow. If it is true
dict contains the authenticated credentials of each stage. the first dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client. the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call). request (which may have been given only in a previous call).
session_id is the ID of this session, either passed in by the client
or assigned by the call to check_auth
""" """
authdict = None authdict = None
@ -103,7 +107,10 @@ class AuthHandler(BaseHandler):
if not authdict: if not authdict:
defer.returnValue( defer.returnValue(
(False, self._auth_dict_for_flows(flows, session), clientdict) (
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
) )
if 'creds' not in session: if 'creds' not in session:
@ -122,12 +129,11 @@ class AuthHandler(BaseHandler):
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds) logger.info("Auth completed with creds: %r", creds)
self._remove_session(session) defer.returnValue((True, creds, clientdict, session['id']))
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict)) defer.returnValue((False, ret, clientdict, session['id']))
@defer.inlineCallbacks @defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip): def add_oob_auth(self, stagetype, authdict, clientip):
@ -154,6 +160,43 @@ class AuthHandler(BaseHandler):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
def get_session_id(self, clientdict):
"""
Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not
send a session ID, returns None.
"""
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
return sid
def set_session_data(self, session_id, key, value):
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param value: (any) The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
"""
Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param default: (any) Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
@ -432,13 +475,18 @@ class AuthHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else []
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id) yield self.store.user_delete_access_tokens(
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id) user_id, except_access_token_ids
yield self.store.flush_user(user_id) )
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):
@ -450,11 +498,18 @@ class AuthHandler(BaseHandler):
def _save_session(self, session): def _save_session(self, session):
# TODO: Persistent storage # TODO: Persistent storage
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session self.sessions[session["id"]] = session
self._prune_sessions()
def _remove_session(self, session): def _prune_sessions(self):
logger.debug("Removing session %s", session) for sid, sess in self.sessions.items():
del self.sessions[session["id"]] last_used = 0
if 'last_used' in sess:
last_used = sess['last_used']
now = self.hs.get_clock().time_msec()
if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
del self.sessions[sid]
def hash(self, password): def hash(self, password):
"""Computes a secure hash of password. """Computes a secure hash of password.
@ -477,4 +532,4 @@ class AuthHandler(BaseHandler):
Returns: Returns:
Whether self.hash(password) == stored_hash (bool). Whether self.hash(password) == stored_hash (bool).
""" """
return bcrypt.checkpw(password, stored_hash) return bcrypt.hashpw(password, stored_hash) == stored_hash

View File

@ -17,9 +17,9 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomAlias from synapse.types import RoomAlias, UserID
import logging import logging
import string import string
@ -32,13 +32,15 @@ class DirectoryHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs) super(DirectoryHandler, self).__init__(hs)
self.state = hs.get_state_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( self.federation.register_query_handler(
"directory", self.on_directory_query "directory", self.on_directory_query
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None): def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services # general association creation for both human users and app services
for wchar in string.whitespace: for wchar in string.whitespace:
@ -60,7 +62,8 @@ class DirectoryHandler(BaseHandler):
yield self.store.create_room_alias_association( yield self.store.create_room_alias_association(
room_alias, room_alias,
room_id, room_id,
servers servers,
creator=creator,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -77,7 +80,7 @@ class DirectoryHandler(BaseHandler):
400, "This alias is reserved by an application service.", 400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
yield self._create_association(room_alias, room_id, servers) yield self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_appservice_association(self, service, room_alias, room_id, def create_appservice_association(self, service, room_alias, room_id,
@ -92,10 +95,14 @@ class DirectoryHandler(BaseHandler):
yield self._create_association(room_alias, room_id, servers) yield self._create_association(room_alias, room_id, servers)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_association(self, user_id, room_alias): def delete_association(self, requester, user_id, room_alias):
# association deletion for human users # association deletion for human users
# TODO Check if server admin can_delete = yield self._user_can_delete_alias(room_alias, user_id)
if not can_delete:
raise AuthError(
403, "You don't have permission to delete the alias.",
)
can_delete = yield self.can_modify_alias( can_delete = yield self.can_modify_alias(
room_alias, room_alias,
@ -107,7 +114,25 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
yield self._delete_association(room_alias) room_id = yield self._delete_association(room_alias)
try:
yield self.send_room_alias_update_event(
requester,
requester.user.to_string(),
room_id
)
yield self._update_canonical_alias(
requester,
requester.user.to_string(),
room_id,
room_alias,
)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
defer.returnValue(room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias): def delete_appservice_association(self, service, room_alias):
@ -124,11 +149,9 @@ class DirectoryHandler(BaseHandler):
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
yield self.store.delete_room_alias(room_alias) room_id = yield self.store.delete_room_alias(room_alias)
# TODO - Looks like _update_room_alias_event has never been implemented defer.returnValue(room_id)
# if room_id:
# yield self._update_room_alias_events(user_id, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association(self, room_alias): def get_association(self, room_alias):
@ -175,8 +198,8 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first. # If this server is in the list of servers, return it first.
if self.server_name in servers: if self.server_name in servers:
servers = ( servers = (
[self.server_name] [self.server_name] +
+ [s for s in servers if s != self.server_name] [s for s in servers if s != self.server_name]
) )
else: else:
servers = list(servers) servers = list(servers)
@ -212,17 +235,44 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_room_alias_update_event(self, user_id, room_id): def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event({ yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Aliases, "type": EventTypes.Aliases,
"state_key": self.hs.hostname, "state_key": self.hs.hostname,
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
"content": {"aliases": aliases}, "content": {"aliases": aliases},
}, ratelimit=False) },
ratelimit=False
)
@defer.inlineCallbacks
def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
alias_event = yield self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
alias_str = room_alias.to_string()
if not alias_event or alias_event.content.get("alias", "") != alias_str:
return
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
"state_key": "",
"room_id": room_id,
"sender": user_id,
"content": {},
},
ratelimit=False
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias): def get_association_from_room_alias(self, room_alias):
@ -257,3 +307,13 @@ class DirectoryHandler(BaseHandler):
return return
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
defer.returnValue(True) defer.returnValue(True)
@defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id):
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator and creator == user_id:
defer.returnValue(True)
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
defer.returnValue(is_admin)

View File

@ -18,6 +18,8 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.api.constants import Membership, EventTypes
from synapse.events import EventBase
from ._base import BaseHandler from ._base import BaseHandler
@ -28,14 +30,6 @@ import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user):
return distributor.fire("started_user_eventstream", user)
def stopped_user_eventstream(distributor, user):
return distributor.fire("stopped_user_eventstream", user)
class EventStreamHandler(BaseHandler): class EventStreamHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -54,61 +48,6 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def started_stream(self, user):
"""Tells the presence handler that we have started an eventstream for
the user:
Args:
user (User): The user who started a stream.
Returns:
A deferred that completes once their presence has been updated.
"""
if user not in self._streams_per_user:
# Make sure we set the streams per user to 1 here rather than
# setting it to zero and incrementing the value below.
# Otherwise this may race with stopped_stream causing the
# user to be erased from the map before we have a chance
# to increment it.
self._streams_per_user[user] = 1
if user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield started_user_eventstream(self.distributor, user)
else:
self._streams_per_user[user] += 1
def stopped_stream(self, user):
"""If there are no streams for a user this starts a timer that will
notify the presence handler that we haven't got an event stream for
the user unless the user starts a new stream in 30 seconds.
Args:
user (User): The user who stopped a stream.
"""
self._streams_per_user[user] -= 1
if not self._streams_per_user[user]:
del self._streams_per_user[user]
# 30 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug("_later stopped_user_eventstream %s", user)
self._stop_timer_per_user.pop(user, None)
return stopped_user_eventstream(self.distributor, user)
logger.debug("Scheduling _later: for %s", user)
self._stop_timer_per_user[user] = (
self.clock.call_later(30, _later)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
@ -119,18 +58,19 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down. If `only_keys` is not None, events from keys will be sent down.
""" """
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_handlers().presence_handler
try: context = yield presence_handler.user_syncing(
if affect_presence: auth_user_id, affect_presence=affect_presence,
yield self.started_stream(auth_user) )
with context:
if timeout: if timeout:
# If they've set a timeout set a minimum limit. # If they've set a timeout set a minimum limit.
timeout = max(timeout, 500) timeout = max(timeout, 500)
# Add some randomness to this value to try and mitigate against # Add some randomness to this value to try and mitigate against
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,
@ -138,6 +78,34 @@ class EventStreamHandler(BaseHandler):
is_guest=is_guest, explicit_room_id=room_id is_guest=is_guest, explicit_room_id=room_id
) )
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
for event in events:
if not isinstance(event, EventBase):
continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.store.get_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
UserID.from_string(event.state_key),
as_event=True,
)
to_add.append(ev)
events.extend(to_add)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunks = [ chunks = [
@ -152,10 +120,6 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
finally:
if affect_presence:
self.stopped_stream(auth_user)
class EventHandler(BaseHandler): class EventHandler(BaseHandler):

View File

@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
"""Contains handlers for federation events.""" """Contains handlers for federation events."""
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from ._base import BaseHandler from ._base import BaseHandler
@ -221,19 +224,11 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -244,12 +239,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, self
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events( event_to_state = yield self.store.get_state_for_events(
@ -483,7 +472,7 @@ class FederationHandler(BaseHandler):
limit=100, limit=100,
extremities=[e for e in extremities.keys()] extremities=[e for e in extremities.keys()]
) )
except SynapseError: except SynapseError as e:
logger.info( logger.info(
"Failed to backfill from %s because %s", "Failed to backfill from %s because %s",
dom, e, dom, e,
@ -643,19 +632,11 @@ class FederationHandler(BaseHandler):
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[joinee] extra_users=[joinee]
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
@ -730,18 +711,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
@ -811,19 +784,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[target_user], extra_users=[target_user],
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -848,7 +813,22 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
signed_event signed_event
) )
defer.returnValue(None)
context = yield self.state_handler.compute_event_context(event)
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=False,
)
target_user = UserID.from_string(event.state_key)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
@ -948,18 +928,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event new_pdu = event
destinations = set() destinations = set()
@ -1113,6 +1085,12 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context, self
)
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
@ -1186,7 +1164,13 @@ class FederationHandler(BaseHandler):
try: try:
self.auth.check(e, auth_events=auth_for_e) self.auth.check(e, auth_events=auth_for_e)
except AuthError as err: except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
# rooms whose senders do not have the correct sigil; these
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
logger.warn( logger.warn(
"Rejecting %s because %s", "Rejecting %s because %s",
e.event_id, err.msg e.event_id, err.msg
@ -1654,19 +1638,15 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite(self, invite): def exchange_third_party_invite(
sender = invite["sender"] self,
room_id = invite["room_id"] sender_user_id,
target_user_id,
if "signed" not in invite or "token" not in invite["signed"]: room_id,
logger.info( signed,
"Discarding received notification of third party invite " ):
"without signed: %s" % (invite,)
)
return
third_party_invite = { third_party_invite = {
"signed": invite["signed"], "signed": signed,
} }
event_dict = { event_dict = {
@ -1676,8 +1656,8 @@ class FederationHandler(BaseHandler):
"third_party_invite": third_party_invite, "third_party_invite": third_party_invite,
}, },
"room_id": room_id, "room_id": room_id,
"sender": sender, "sender": sender_user_id,
"state_key": invite["mxid"], "state_key": target_user_id,
} }
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
@ -1690,11 +1670,11 @@ class FederationHandler(BaseHandler):
) )
self.auth.check(event, context.current_state) self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
yield self.replication_layer.forward_third_party_invite( yield self.replication_layer.forward_third_party_invite(
destinations, destinations,
room_id, room_id,
@ -1715,13 +1695,13 @@ class FederationHandler(BaseHandler):
) )
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event) returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context) yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context): def add_display_name_to_third_party_invite(self, event_dict, event, context):
@ -1745,17 +1725,69 @@ class FederationHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_keyserver(self, event, auth_events): def _check_signature(self, event, auth_events):
token = event.content["third_party_invite"]["signed"]["token"] """
Checks that the signature in the event is consistent with its invite.
:param event (Event): The m.room.member event to check
:param auth_events (dict<(event type, state_key), event>)
:raises
AuthError if signature didn't match any keys, or key has been
revoked,
SynapseError if a transient error meant a key couldn't be checked
for revocation.
"""
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
invite_event = auth_events.get( invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )
if not invite_event:
raise AuthError(403, "Could not find invite")
last_exception = None
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
public_key = public_key_object["public_key"]
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
if "key_validity_url" in public_key_object:
yield self._check_key_revocation(
public_key,
public_key_object["key_validity_url"]
)
return
except Exception as e:
last_exception = e
raise last_exception
@defer.inlineCallbacks
def _check_key_revocation(self, public_key, url):
"""
Checks whether public_key has been revoked.
:param public_key (str): base-64 encoded public key.
:param url (str): Key revocation URL.
:raises
AuthError if they key has been revoked.
SynapseError if a transient error meant a key couldn't be checked
for revocation.
"""
try: try:
response = yield self.hs.get_simple_http_client().get_json( response = yield self.hs.get_simple_http_client().get_json(
invite_event.content["key_validity_url"], url,
{"public_key": invite_event.content["public_key"]} {"public_key": public_key}
) )
except Exception: except Exception:
raise SynapseError( raise SynapseError(

View File

@ -36,14 +36,15 @@ class IdentityHandler(BaseHandler):
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
self.trust_any_id_server_just_for_testing_do_not_use = (
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
@defer.inlineCallbacks @defer.inlineCallbacks
def threepid_from_creds(self, creds): def threepid_from_creds(self, creds):
yield run_on_reactor() yield run_on_reactor()
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds: if 'id_server' in creds:
id_server = creds['id_server'] id_server = creds['id_server']
elif 'idServer' in creds: elif 'idServer' in creds:
@ -58,7 +59,16 @@ class IdentityHandler(BaseHandler):
else: else:
raise SynapseError(400, "No client_secret in creds") raise SynapseError(400, "No client_secret in creds")
if id_server not in trustedIdServers: if id_server not in self.trusted_id_servers:
if self.trust_any_id_server_just_for_testing_do_not_use:
logger.warn(
"Trusting untrustworthy ID server %r even though it isn't"
" in the trusted id list for testing because"
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
" is set in the config",
id_server,
)
else:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' + logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server) 'credentials', id_server)
defer.returnValue(None) defer.returnValue(None)

View File

@ -16,12 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError, AuthError, Codes from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken from synapse.types import UserID, RoomStreamToken, StreamToken
@ -105,8 +104,6 @@ class MessageHandler(BaseHandler):
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token) room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace( pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token) "room_key", str(room_token)
@ -117,26 +114,30 @@ class MessageHandler(BaseHandler):
membership, member_event_id = yield self._check_in_room_or_world_readable( membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id room_id, user_id
) )
if source_config.direction == 'b':
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
room_id, room_token.stream
)
if membership == Membership.LEAVE: if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room. # they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event( leave_token = yield self.store.get_topological_token_for_event(
member_event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological: if leave_token.topological < max_topo:
source_config.from_key = str(leave_token) source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological room_id, max_topo
) )
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
@ -195,11 +196,24 @@ class MessageHandler(BaseHandler):
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership == Membership.JOIN: if membership == Membership.JOIN:
joinee = UserID.from_string(builder.state_key)
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield collect_presencelike_data( yield collect_presencelike_data(
self.distributor, joinee, builder.content self.distributor, target, builder.content
)
elif membership == Membership.INVITE:
profile = self.hs.get_handlers().profile_handler
content = builder.content
try:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
) )
if token_id is not None: if token_id is not None:
@ -214,7 +228,7 @@ class MessageHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def send_event(self, event, context, ratelimit=True, is_guest=False): def send_nonmember_event(self, requester, event, context, ratelimit=True):
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -224,55 +238,70 @@ class MessageHandler(BaseHandler):
ratelimit (bool): Whether to rate limit this send. ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest. is_guest (bool): Whether the sender is a guest.
""" """
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if ratelimit:
self.ratelimit(event.sender)
if event.is_state(): if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = self.deduplicate_state_event(event, context)
if prev_state and event.user_id == prev_state.user_id: if prev_state is not None:
prev_content = encode_canonical_json(prev_state.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
# Duplicate suppression for state updates with same sender
# and content.
defer.returnValue(prev_state) defer.returnValue(prev_state)
if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context, is_guest=is_guest)
else:
yield self.handle_new_client_event( yield self.handle_new_client_event(
requester=requester,
event=event, event=event,
context=context, context=context,
ratelimit=ratelimit,
) )
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
with PreserveLoggingContext(): yield presence.bump_presence_active_time(user)
presence.bump_presence_active_time(user)
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event = context.current_state.get((event.type, event.state_key))
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
return None
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_nonmember_event(
token_id=None, txn_id=None, is_guest=False): self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
""" """
Creates an event, then sends it. Creates an event, then sends it.
See self.create_event and self.send_event. See self.create_event and self.send_nonmember_event.
""" """
event, context = yield self.create_event( event, context = yield self.create_event(
event_dict, event_dict,
token_id=token_id, token_id=requester.access_token_id,
txn_id=txn_id txn_id=txn_id
) )
yield self.send_event( yield self.send_nonmember_event(
requester,
event, event,
context, context,
ratelimit=ratelimit, ratelimit=ratelimit,
is_guest=is_guest
) )
defer.returnValue(event) defer.returnValue(event)
@ -633,8 +662,8 @@ class MessageHandler(BaseHandler):
user_id, messages, is_peeking=is_peeking user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken(token[1], 0, 0, 0, 0) end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -658,10 +687,6 @@ class MessageHandler(BaseHandler):
room_id=room_id, room_id=room_id,
) )
# TODO(paul): I wish I was called with user objects not user_id
# strings...
auth_user = UserID.from_string(user_id)
# TODO: These concurrently # TODO: These concurrently
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
state = [ state = [
@ -686,13 +711,11 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members], [m.user_id for m in room_members],
auth_user=auth_user,
as_event=True, as_event=True,
check_auth=False,
) )
defer.returnValue(states.values()) defer.returnValue(states)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_receipts(): def get_receipts():

File diff suppressed because it is too large Load Diff

View File

@ -16,8 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, Requester
from synapse.types import UserID
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from ._base import BaseHandler from ._base import BaseHandler
@ -49,6 +48,9 @@ class ProfileHandler(BaseHandler):
distributor = hs.get_distributor() distributor = hs.get_distributor()
self.distributor = distributor self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user) distributor.observe("registered_user", self.registered_user)
distributor.observe( distributor.observe(
@ -87,13 +89,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["displayname"]) defer.returnValue(result["displayname"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, auth_user, new_displayname): def set_displayname(self, target_user, requester, new_displayname):
"""target_user is the user whose displayname is to be changed; """target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '': if new_displayname == '':
@ -107,7 +109,7 @@ class ProfileHandler(BaseHandler):
"displayname": new_displayname, "displayname": new_displayname,
}) })
yield self._update_join_states(target_user) yield self._update_join_states(requester)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_avatar_url(self, target_user): def get_avatar_url(self, target_user):
@ -137,13 +139,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["avatar_url"]) defer.returnValue(result["avatar_url"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_avatar_url(self, target_user, auth_user, new_avatar_url): def set_avatar_url(self, target_user, requester, new_avatar_url):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(
@ -154,7 +156,7 @@ class ProfileHandler(BaseHandler):
"avatar_url": new_avatar_url, "avatar_url": new_avatar_url,
}) })
yield self._update_join_states(target_user) yield self._update_join_states(requester)
@defer.inlineCallbacks @defer.inlineCallbacks
def collect_presencelike_data(self, user, state): def collect_presencelike_data(self, user, state):
@ -197,32 +199,30 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_join_states(self, user): def _update_join_states(self, requester):
user = requester.user
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
return return
self.ratelimit(user.to_string()) self.ratelimit(requester)
joins = yield self.store.get_rooms_for_user( joins = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),
) )
for j in joins: for j in joins:
content = { handler = self.hs.get_handlers().room_member_handler
"membership": Membership.JOIN,
}
yield collect_presencelike_data(self.distributor, user, content)
msg_handler = self.hs.get_handlers().message_handler
try: try:
yield msg_handler.create_and_send_event({ # Assume the user isn't a guest because we don't let guests set
"type": EventTypes.Member, # profile or avatar data.
"room_id": j.room_id, requester = Requester(user, "", False)
"state_key": user.to_string(), yield handler.update_membership(
"content": content, requester,
"sender": user.to_string() user,
}, ratelimit=False) j.room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to update join event for room %s - %s", "Failed to update join event for room %s - %s",

View File

@ -36,8 +36,6 @@ class ReceiptsHandler(BaseHandler):
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self._receipt_cache = None
@defer.inlineCallbacks @defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id, def received_client_receipt(self, room_id, receipt_type, user_id,
event_id): event_id):

View File

@ -21,7 +21,6 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
@ -45,21 +44,33 @@ class RegistrationHandler(BaseHandler):
self.distributor.declare("registered_user") self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None): def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
yield run_on_reactor() yield run_on_reactor()
if urllib.quote(localpart.encode('utf-8')) != localpart: if urllib.quote(localpart.encode('utf-8')) != localpart:
raise SynapseError( raise SynapseError(
400, 400,
"User ID can only contain characters a-z, 0-9, or '-./'", "User ID can only contain characters a-z, 0-9, or '_-./'",
Codes.INVALID_USERNAME Codes.INVALID_USERNAME
) )
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) if assigned_user_id:
if user_id == assigned_user_id:
return
else:
raise SynapseError(
400,
"A different user ID has already been registered for this session",
)
yield self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
@ -91,7 +102,7 @@ class RegistrationHandler(BaseHandler):
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be randomly generated. one will be generated.
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
@ -108,6 +119,18 @@ class RegistrationHandler(BaseHandler):
if localpart: if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token) yield self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
if not was_guest:
try:
int(localpart)
raise RegistrationError(
400,
"Numeric user IDs are reserved for guest users."
)
except ValueError:
pass
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -118,38 +141,37 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
was_guest=guest_access_token is not None, was_guest=was_guest,
make_guest=make_guest, make_guest=make_guest,
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
else: else:
# autogen a random user ID # autogen a sequential user ID
attempts = 0 attempts = 0
user_id = None
token = None token = None
while not user_id: user = None
try: while not user:
localpart = self._generate_user_id() localpart = yield self._generate_user_id(attempts > 0)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token: if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash) password_hash=password_hash,
make_guest=make_guest
yield registered_user(self.distributor, user) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user = None
user_id = None user_id = None
token = None token = None
attempts += 1 attempts += 1
if attempts > 5: yield registered_user(self.distributor, user)
raise RegistrationError(
500, "Cannot generate user ID.")
# We used to generate default identicons here, but nowadays # We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding # we want clients to generate their own as part of their branding
@ -169,13 +191,21 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.", 400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
service_id = service.id if service.is_exclusive_user(user_id) else None
yield self.check_user_id_not_appservice_exclusive(
user_id, allowed_appservice=service
)
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash="" password_hash="",
appservice_id=service_id,
) )
registered_user(self.distributor, user) yield registered_user(self.distributor, user)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -215,7 +245,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
@ -224,7 +254,7 @@ class RegistrationHandler(BaseHandler):
password_hash=None password_hash=None
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
except Exception, e: except Exception as e:
yield self.store.add_access_token_to_user(user_id, token) yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors # Ignore Registration errors
logger.exception(e) logger.exception(e)
@ -267,12 +297,14 @@ class RegistrationHandler(BaseHandler):
yield identity_handler.bind_threepid(c, user_id) yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_id_is_valid(self, user_id): def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
# valid user IDs must not clash with any user ID namespaces claimed by # valid user IDs must not clash with any user ID namespaces claimed by
# application services. # application services.
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_user(user_id) s for s in services
if s.is_interested_in_user(user_id)
and s != allowed_appservice
] ]
for service in interested_services: for service in interested_services:
if service.is_exclusive_user(user_id): if service.is_exclusive_user(user_id):
@ -281,8 +313,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
def _generate_user_id(self): @defer.inlineCallbacks
return "-" + stringutils.random_string(18) def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
self._next_generated_user_id += 1
defer.returnValue(str(id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response): def _validate_captcha(self, ip_addr, private_key, challenge, response):
@ -323,3 +363,18 @@ class RegistrationHandler(BaseHandler):
def auth_handler(self): def auth_handler(self):
return self.hs.get_handlers().auth_handler return self.hs.get_handlers().auth_handler
@defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id):
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
defer.returnValue(access_token)
_, access_token = yield self.register(
generate_token=True,
make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
)
defer.returnValue(access_token)

File diff suppressed because it is too large Load Diff

View File

@ -18,18 +18,22 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure
from synapse.push.clientformat import format_push_rules_for_user
from twisted.internet import defer from twisted.internet import defer
import collections import collections
import logging import logging
import itertools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [ SyncConfig = collections.namedtuple("SyncConfig", [
"user", "user",
"filter", "filter_collection",
"is_guest", "is_guest",
]) ])
@ -72,7 +76,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
) )
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
"room_id", # str "room_id", # str
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
@ -118,7 +122,11 @@ class SyncResult(collections.namedtuple("SyncResult", [
events. events.
""" """
return bool( return bool(
self.presence or self.joined or self.invited or self.archived self.presence or
self.joined or
self.invited or
self.archived or
self.account_data
) )
@ -139,11 +147,21 @@ class SyncHandler(BaseHandler):
A Deferred SyncResult. A Deferred SyncResult.
""" """
context = LoggingContext.current_context()
if context:
if since_token is None:
context.tag = "initial_sync"
elif full_state:
context.tag = "full_state_sync"
else:
context.tag = "incremental_sync"
if timeout == 0 or since_token is None or full_state: if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
result = yield self.current_sync_for_user(sync_config, since_token, result = yield self.current_sync_for_user(
full_state=full_state) sync_config, since_token, full_state=full_state,
)
defer.returnValue(result) defer.returnValue(result)
else: else:
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
@ -151,7 +169,7 @@ class SyncHandler(BaseHandler):
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user.to_string(), timeout, current_sync_callback, sync_config.user.to_string(), timeout, current_sync_callback,
from_token=since_token from_token=since_token,
) )
defer.returnValue(result) defer.returnValue(result)
@ -166,18 +184,6 @@ class SyncHandler(BaseHandler):
else: else:
return self.incremental_sync_with_gap(sync_config, since_token) return self.incremental_sync_with_gap(sync_config, since_token)
def last_read_event_id_for_room_and_user(self, room_id, user_id, ephemeral_by_room):
if room_id not in ephemeral_by_room:
return None
for e in ephemeral_by_room[room_id]:
if e['type'] != 'm.receipt':
continue
for receipt_event_id, val in e['content'].items():
if 'm.read' in val:
if user_id in val['m.read']:
return receipt_event_id
return None
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync(self, sync_config, timeline_since_token): def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state. """Get a sync for a client which is starting without any state.
@ -204,9 +210,9 @@ class SyncHandler(BaseHandler):
key=None key=None
) )
membership_list = (Membership.INVITE, Membership.JOIN) membership_list = (
if sync_config.filter.include_leave: Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
membership_list += (Membership.LEAVE, Membership.BAN) )
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
@ -219,6 +225,10 @@ class SyncHandler(BaseHandler):
) )
) )
account_data['m.push_rules'] = yield self.push_rules_for_user(
sync_config.user
)
tags_by_room = yield self.store.get_tags_for_user( tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -227,9 +237,14 @@ class SyncHandler(BaseHandler):
invited = [] invited = []
archived = [] archived = []
deferreds = [] deferreds = []
for event in room_list:
room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
for room_list_chunk in room_list_chunks:
for event in room_list_chunk:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
room_sync_deferred = self.full_state_sync_for_joined_room( room_sync_deferred = preserve_fn(
self.full_state_sync_for_joined_room
)(
room_id=event.room_id, room_id=event.room_id,
sync_config=sync_config, sync_config=sync_config,
now_token=now_token, now_token=now_token,
@ -247,10 +262,18 @@ class SyncHandler(BaseHandler):
invite=invite, invite=invite,
)) ))
elif event.membership in (Membership.LEAVE, Membership.BAN): elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if sync_config.user.to_string() == event.sender:
continue
leave_token = now_token.copy_and_replace( leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,) "room_key", "s%d" % (event.stream_ordering,)
) )
room_sync_deferred = self.full_state_sync_for_archived_room( room_sync_deferred = preserve_fn(
self.full_state_sync_for_archived_room
)(
sync_config=sync_config, sync_config=sync_config,
room_id=event.room_id, room_id=event.room_id,
leave_event_id=event.event_id, leave_event_id=event.event_id,
@ -266,9 +289,17 @@ class SyncHandler(BaseHandler):
deferreds, consumeErrors=True deferreds, consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)
presence = sync_config.filter_collection.filter_presence(
presence
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -289,29 +320,26 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token room_id, sync_config, now_token, since_token=timeline_since_token
) )
notifs = yield self.unread_notifs_for_room_id( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, ephemeral_by_room room_id, sync_config,
now_token=now_token,
since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
batch=batch,
full_state=True,
) )
unread_notifications = {} defer.returnValue(room_sync)
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
current_state = yield self.get_state_at(room_id, now_token) @defer.inlineCallbacks
def push_rules_for_user(self, user):
defer.returnValue(JoinedSyncResult( user_id = user.to_string()
room_id=room_id, rawrules = yield self.store.get_push_rules_for_user(user_id)
timeline=batch, enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
state=current_state, rules = format_push_rules_for_user(user, rawrules, enabled_map)
ephemeral=ephemeral_by_room.get(room_id, []), defer.returnValue(rules)
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
unread_notifications=unread_notifications,
))
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
account_data_events = [] account_data_events = []
@ -356,6 +384,7 @@ class SyncHandler(BaseHandler):
typing events for that room. typing events for that room.
""" """
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
@ -365,7 +394,7 @@ class SyncHandler(BaseHandler):
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -388,7 +417,7 @@ class SyncHandler(BaseHandler):
receipts, receipt_key = yield receipt_source.get_new_events( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -403,31 +432,20 @@ class SyncHandler(BaseHandler):
defer.returnValue((now_token, ephemeral_by_room)) defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token, tags_by_room, timeline_since_token, tags_by_room,
account_data_by_room): account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred ArchivedSyncResult.
""" """
batch = yield self.load_filtered_recents( return self.incremental_sync_for_archived_room(
room_id, sync_config, leave_token, since_token=timeline_since_token sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
account_data_by_room, full_state=True, leave_token=leave_token,
) )
leave_state = yield self.store.get_state_for_event(leave_event_id)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state,
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
))
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token): def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to """ Get the incremental delta needed to bring the client up to
@ -444,19 +462,12 @@ class SyncHandler(BaseHandler):
presence, presence_key = yield presence_source.get_new_events( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter_collection.presence_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)
# We now fetch all ephemeral events for this room in order to get
# this users current read receipt. This could almost certainly be
# optimised.
_, all_ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token
)
now_token, ephemeral_by_room = yield self.ephemeral_by_room( now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token sync_config, now_token, since_token
) )
@ -473,139 +484,169 @@ class SyncHandler(BaseHandler):
sync_config.user sync_config.user
) )
timeline_limit = sync_config.filter.timeline_limit() user_id = sync_config.user.to_string()
room_events, _ = yield self.store.get_room_events_stream( timeline_limit = sync_config.filter_collection.timeline_limit()
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
tags_by_room = yield self.store.get_updated_tags( tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(), user_id,
since_token.account_data_key, since_token.account_data_key,
) )
account_data, account_data_by_room = ( account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user( yield self.store.get_updated_account_data_for_user(
sync_config.user.to_string(), user_id,
since_token.account_data_key, since_token.account_data_key,
) )
) )
joined = [] push_rules_changed = yield self.store.have_push_rules_changed_for_user(
user_id, int(since_token.push_rules_key)
)
if push_rules_changed:
account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
)
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
archived = [] archived = []
if len(room_events) <= timeline_limit: invited = []
# There is no gap in any of the rooms. Therefore we can just for room_id, events in mem_change_events_by_room_id.items():
# partition the new events by room and return them. non_joins = [e for e in events if e.membership != Membership.JOIN]
logger.debug("Got %i events for incremental sync - not limited", has_join = len(non_joins) != len(events)
len(room_events))
invite_events = [] # We want to figure out if we joined the room at some point since
leave_events = [] # the last sync (even if we have since left). This is to make sure
events_by_room_id = {} # we do send down the room, and with full state, where necessary
for event in room_events: if room_id in joined_room_ids or has_join:
events_by_room_id.setdefault(event.room_id, []).append(event) old_state = yield self.get_state_at(room_id, since_token)
if event.room_id not in joined_room_ids: old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
if (event.type == EventTypes.Member if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
and event.state_key == sync_config.user.to_string()): newly_joined_rooms.append(room_id)
if event.membership == Membership.INVITE:
invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
for room_id in joined_room_ids: if room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) continue
logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if recents: if not non_joins:
prev_batch = now_token.copy_and_replace( continue
"room_key", recents[0].internal_metadata.before
)
else:
prev_batch = now_token
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=TimelineBatch(
events=recents,
prev_batch=prev_batch,
limited=limited,
),
state=state,
ephemeral=ephemeral_by_room.get(room_id, []),
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
unread_notifications={},
)
logger.debug("Result for room %s: %r", room_id, room_sync)
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if room_sync: if room_sync:
notifs = yield self.unread_notifs_for_room_id( invited.append(room_sync)
room_id, sync_config, all_ephemeral_by_room
)
if notifs is not None: # Always include leave/ban events. Just take the last one.
notif_dict = room_sync.unread_notifications # TODO: How do we handle ban -> leave in same batch?
notif_dict["notification_count"] = len(notifs) leave_events = [
notif_dict["highlight_count"] = len([ e for e in non_joins
1 for notif in notifs if e.membership in (Membership.LEAVE, Membership.BAN)
if _action_has_highlight(notif["actions"]) ]
])
joined.append(room_sync) if leave_events:
leave_event = leave_events[-1]
else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string()
)
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room, account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
)
if room_sync:
joined.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room( room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token, tags_by_room, sync_config, room_id, leave_event.event_id, since_token,
account_data_by_room tags_by_room, account_data_by_room,
full_state=room_id in newly_joined_rooms
) )
if room_sync: if room_sync:
archived.append(room_sync) archived.append(room_sync)
invited = [ # Get all events for rooms we're currently joined to.
InvitedSyncResult(room_id=event.room_id, invite=event) room_to_events = yield self.store.get_room_events_stream_for_rooms(
for event in invite_events room_ids=joined_room_ids,
] from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
joined = []
# We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about.
for room_id in joined_room_ids:
room_entry = room_to_events.get(room_id, None)
if room_entry:
events, start_key = room_entry
prev_batch_token = now_token.copy_and_replace("room_key", start_key)
newly_joined_room = room_id in newly_joined_rooms
full_state = newly_joined_room
batch = yield self.load_filtered_recents(
room_id, sync_config, prev_batch_token,
since_token=since_token,
recents=events,
newly_joined_room=newly_joined_room,
)
else:
batch = TimelineBatch(
events=[],
prev_batch=since_token,
limited=False,
)
full_state = False
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id=room_id,
sync_config=sync_config,
since_token=since_token,
now_token=now_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
batch=batch,
full_state=full_state,
)
if room_sync:
joined.append(room_sync)
# For each newly joined room, we want to send down presence of
# existing users.
presence_handler = self.hs.get_handlers().presence_handler
extra_presence_users = set()
for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(event.room_id)
extra_presence_users.update(users)
# For each new member, send down presence.
for joined_sync in joined:
it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
extra_presence_users.add(event.state_key)
states = yield presence_handler.get_states(
[u for u in extra_presence_users if u != user_id],
as_event=True,
)
presence.extend(states)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)
presence = sync_config.filter_collection.filter_presence(
presence
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -614,38 +655,56 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None, recents=None, newly_joined_room=False):
""" """
:returns a Deferred TimelineBatch :returns a Deferred TimelineBatch
""" """
limited = True with Measure(self.clock, "load_filtered_recents"):
recents = []
filtering_factor = 2 filtering_factor = 2
timeline_limit = sync_config.filter.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100) load_limit = max(timeline_limit * filtering_factor, 10)
max_repeat = 3 # Only try a few times per room, otherwise max_repeat = 5 # Only try a few times per room, otherwise
room_key = now_token.room_key room_key = now_token.room_key
end_key = room_key end_key = room_key
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
else:
limited = False
if recents is not None:
recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
recents,
)
else:
recents = []
since_key = None
if since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat: while limited and len(recents) < timeline_limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room( events, end_key = yield self.store.get_room_events_stream_for_room(
room_id, room_id,
limit=load_limit + 1, limit=load_limit + 1,
from_token=since_token.room_key if since_token else None, from_key=since_key,
end_token=end_key, to_key=end_key,
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
) )
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
is_peeking=sync_config.is_guest,
) )
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents
if len(events) <= load_limit: if len(events) <= load_limit:
limited = False limited = False
break
max_repeat -= 1 max_repeat -= 1
if len(recents) > timeline_limit: if len(recents) > timeline_limit:
@ -658,7 +717,9 @@ class SyncHandler(BaseHandler):
) )
defer.returnValue(TimelineBatch( defer.returnValue(TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=limited events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -666,112 +727,92 @@ class SyncHandler(BaseHandler):
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room, ephemeral_by_room, tags_by_room,
account_data_by_room, account_data_by_room,
all_ephemeral_by_room): batch, full_state=False):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
Returns:
A Deferred JoinedSyncResult
"""
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed.
batch = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logger.debug("Recents %r", batch)
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state = yield self.compute_state_delta( state = yield self.compute_state_delta(
since_token=since_token, room_id, batch, sync_config, since_token, now_token,
previous_state=state_at_previous_sync, full_state=full_state
current_state=current_state,
) )
just_joined = yield self.check_joined_room(sync_config, state) account_data = self.account_data_for_room(
if just_joined: room_id, tags_by_room, account_data_by_room
state = yield self.get_state_at(room_id, now_token) )
notifs = yield self.unread_notifs_for_room_id( account_data = sync_config.filter_collection.filter_room_account_data(
room_id, sync_config, all_ephemeral_by_room account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
) )
unread_notifications = {} unread_notifications = {}
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room
),
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
) )
if room_sync:
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config
)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
since_token, tags_by_room, since_token, tags_by_room,
account_data_by_room): account_data_by_room, full_state,
leave_token=None):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
A Deferred ArchivedSyncResult A Deferred ArchivedSyncResult
""" """
if not leave_token:
stream_token = yield self.store.get_stream_token_for_event( stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id leave_event_id
) )
leave_token = since_token.copy_and_replace("room_key", stream_token) leave_token = since_token.copy_and_replace("room_key", stream_token)
if since_token.is_after(leave_token): if since_token and since_token.is_after(leave_token):
defer.returnValue(None) defer.returnValue(None)
batch = yield self.load_filtered_recents( batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token, room_id, sync_config, leave_token, since_token,
) )
logger.debug("Recents %r", batch) logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id
)
state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
since_token=since_token, room_id, batch, sync_config, since_token, leave_token,
previous_state=state_at_previous_sync, full_state=full_state
current_state=state_events_at_leave, )
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
) )
room_sync = ArchivedSyncResult( room_sync = ArchivedSyncResult(
room_id=leave_event.room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
account_data=self.account_data_for_room( account_data=account_data,
leave_event.room_id, tags_by_room, account_data_by_room
),
) )
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)
@ -812,15 +853,19 @@ class SyncHandler(BaseHandler):
state = {} state = {}
defer.returnValue(state) defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state): @defer.inlineCallbacks
""" Works out the differnce in state between the current state and the def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
state the client got when it last performed a sync. full_state):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str since_token: the point we are comparing against :param str room_id
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the :param TimelineBatch batch: The timeline batch for the room that will
state to compare to be sent to the user.
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the :param sync_config
new state :param str since_token: Token of the end of the previous batch. May be None.
:param str now_token: Token of the end of the current batch.
:param bool full_state: Whether to force returning the full state.
:returns A new event dictionary :returns A new event dictionary
""" """
@ -829,12 +874,65 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event. # updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
state_delta = {} with Measure(self.clock, "compute_state_delta"):
for key, event in current_state.iteritems(): if full_state:
if (key not in previous_state or if batch:
previous_state[key].event_id != event.event_id): current_state = yield self.store.get_state_for_event(
state_delta[key] = event batch.events[-1].event_id
return state_delta )
state = yield self.store.get_state_for_event(
batch.events[0].event_id
)
else:
current_state = yield self.get_state_at(
room_id, stream_position=now_token
)
state = current_state
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state,
previous={},
current=current_state,
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
current_state = yield self.store.get_state_for_event(
batch.events[-1].event_id
)
state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
current=current_state,
)
else:
state = {}
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
def check_joined_room(self, sync_config, state_delta): def check_joined_room(self, sync_config, state_delta):
""" """
@ -855,9 +953,12 @@ class SyncHandler(BaseHandler):
return False return False
@defer.inlineCallbacks @defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room): def unread_notifs_for_room_id(self, room_id, sync_config):
last_unread_event_id = self.last_read_event_id_for_room_and_user( with Measure(self.clock, "unread_notifs_for_room_id"):
room_id, sync_config.user.to_string(), ephemeral_by_room last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_type="m.read"
) )
notifs = [] notifs = []
@ -881,3 +982,40 @@ def _action_has_highlight(actions):
pass pass
return False return False
def _calculate_state(timeline_contains, timeline_start, previous, current):
"""Works out what state to include in a sync response.
Args:
timeline_contains (dict): state in the timeline
timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync)
current (dict): state at the end of the timeline
Returns:
dict
"""
event_id_to_state = {
e.event_id: e
for e in itertools.chain(
timeline_contains.values(),
previous.values(),
timeline_start.values(),
current.values(),
)
}
c_ids = set(e.event_id for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return {
(e.type, e.state_key): e
for e in evs
}

View File

@ -19,11 +19,13 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
from synapse.types import UserID from synapse.types import UserID
import logging import logging
from collections import namedtuple from collections import namedtuple
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -218,10 +220,24 @@ class TypingNotificationHandler(BaseHandler):
"typing_key", self._latest_room_serial, rooms=[room_id] "typing_key", self._latest_room_serial, rooms=[room_id]
) )
def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state.
rows = []
for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id]
typing_bytes = json.dumps([
u.to_string() for u in typing
], ensure_ascii=False)
rows.append((serial, room_id, typing_bytes))
rows.sort()
return rows
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock()
self._handler = None self._handler = None
self._room_member_handler = None self._room_member_handler = None
@ -247,6 +263,7 @@ class TypingNotificationEventSource(object):
} }
def get_new_events(self, from_key, room_ids, **kwargs): def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.handler()

View File

@ -103,7 +103,7 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents? # TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -330,7 +330,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_raw(self, url, args={}): def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -350,6 +350,19 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response) defer.returnValue(e.response)
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_urlencode_arg(arg):
if isinstance(arg, unicode):
return arg.encode('utf-8')
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
return arg
def _print_ex(e): def _print_ex(e):
if hasattr(e, "reasons") and e.reasons: if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons: for ex in e.reasons:

View File

@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
return self.clock.time_bound_deferred( return self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=timeout/1000. if timeout else 60, time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn( response = yield preserve_context_over_fn(

View File

@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter( incoming_requests_counter = metrics.register_counter(
"requests", "requests",
labels=["method", "servlet"], labels=["method", "servlet", "tag"],
) )
outgoing_responses_counter = metrics.register_counter( outgoing_responses_counter = metrics.register_counter(
"responses", "responses",
@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
response_timer = metrics.register_distribution( response_timer = metrics.register_distribution(
"response_time", "response_time",
labels=["method", "servlet"] labels=["method", "servlet", "tag"]
) )
response_ru_utime = metrics.register_distribution( response_ru_utime = metrics.register_distribution(
"response_ru_utime", labels=["method", "servlet"] "response_ru_utime", labels=["method", "servlet", "tag"]
) )
response_ru_stime = metrics.register_distribution( response_ru_stime = metrics.register_distribution(
"response_ru_stime", labels=["method", "servlet"] "response_ru_stime", labels=["method", "servlet", "tag"]
) )
response_db_txn_count = metrics.register_distribution( response_db_txn_count = metrics.register_distribution(
"response_db_txn_count", labels=["method", "servlet"] "response_db_txn_count", labels=["method", "servlet", "tag"]
) )
response_db_txn_duration = metrics.register_distribution( response_db_txn_duration = metrics.register_distribution(
"response_db_txn_duration", labels=["method", "servlet"] "response_db_txn_duration", labels=["method", "servlet", "tag"]
) )
@ -99,9 +99,8 @@ def request_handler(request_handler):
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
try: try:
d = request_handler(self, request) with PreserveLoggingContext(request_context):
with PreserveLoggingContext(): yield request_handler(self, request)
yield d
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
if request.method == "OPTIONS": if request.method == "OPTIONS":
self._send_response(request, 200, {}) self._send_response(request, 200, {})
return return
start_context = LoggingContext.current_context()
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__ servlet_classname = servlet_instance.__class__.__name__
else: else:
servlet_classname = "%r" % callback servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
args = [ args = [
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups() urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
)
try: try:
context = LoggingContext.current_context() context = LoggingContext.current_context()
tag = ""
if context:
tag = context.tag
if context != start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
self.clock.time_msec() - start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage() ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(ru_utime, request.method, servlet_classname) response_ru_utime.inc_by(
response_ru_stime.inc_by(ru_stime, request.method, servlet_classname) ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by( response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname context.db_txn_count, request.method, servlet_classname, tag
) )
response_db_txn_duration.inc_by( response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname context.db_txn_duration, request.method, servlet_classname, tag
) )
except: except:
pass pass
@ -347,10 +367,29 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
"Origin, X-Requested-With, Content-Type, Accept") "Origin, X-Requested-With, Content-Type, Accept")
request.write(json_bytes) request.write(json_bytes)
request.finish() finish_request(request)
return NOT_DONE_YET return NOT_DONE_YET
def finish_request(request):
""" Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
response was written but doesn't provide a convenient or reliable way to
determine if the connection was closed. So we catch and log the RuntimeException
You might think that ``request.notifyFinish`` could be used to tell if the
request was finished. However the deferred it returns won't fire if the
connection was already closed, meaning we'd have to have called the method
right at the start of the request. By the time we want to write the response
it will already be too late.
"""
try:
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
def _request_user_agent_is_curl(request): def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders( user_agents = request.requestHeaders.getRawHeaders(
"User-Agent", default=[] "User-Agent", default=[]

View File

@ -15,14 +15,27 @@
""" This module contains base REST classes for constructing REST servlets. """ """ This module contains base REST classes for constructing REST servlets. """
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, Codes
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False): def parse_integer(request, name, default=None, required=False):
"""Parse an integer parameter from the request string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: An int value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present and not an integer.
"""
if name in request.args: if name in request.args:
try: try:
return int(request.args[name][0]) return int(request.args[name][0])
@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False):
else: else:
if required: if required:
message = "Missing integer query parameter %r" % (name,) message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_boolean(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False):
"""Parse a boolean parameter from the request query string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: A bool value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false".
"""
if name in request.args: if name in request.args:
try: try:
return { return {
@ -53,30 +79,84 @@ def parse_boolean(request, name, default=None, required=False):
else: else:
if required: if required:
message = "Missing boolean query parameter %r" % (name,) message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_string(request, name, default=None, required=False, def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string.
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:param allowed_values (list): List of allowed values for the string,
or None if any value is allowed, defaults to None
:return: A string value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
if name in request.args: if name in request.args:
value = request.args[name][0] value = request.args[name][0]
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)
) )
raise SynapseError(message) raise SynapseError(400, message)
else: else:
return value return value
else: else:
if required: if required:
message = "Missing %s query parameter %r" % (param_type, name) message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_json_value_from_request(request):
"""Parse a JSON value from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:returns: The JSON value.
:raises
SynapseError if the request body couldn't be decoded as JSON.
"""
try:
content_bytes = request.content.read()
except:
raise SynapseError(400, "Error reading JSON content.")
try:
content = simplejson.loads(content_bytes)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
def parse_json_object_from_request(request):
"""Parse a JSON object from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:raises
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
content = parse_json_value_from_request(request)
if type(content) != dict:
message = "Content must be a JSON object."
raise SynapseError(400, message, errcode=Codes.BAD_JSON)
return content
class RestServlet(object): class RestServlet(object):
""" A Synapse REST Servlet. """ A Synapse REST Servlet.

View File

@ -18,10 +18,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
from collections import namedtuple
import logging import logging
@ -71,6 +74,7 @@ class _NotifierUserStream(object):
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key, stream_id, time_now_ms):
@ -86,6 +90,8 @@ class _NotifierUserStream(object):
) )
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred noify_deferred = self.notify_deferred
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token) noify_deferred.callback(self.current_token)
@ -118,6 +124,11 @@ class _NotifierUserStream(object):
return _NotificationListener(self.notify_deferred.observe()) return _NotificationListener(self.notify_deferred.observe())
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
new events available for it. new events available for it.
@ -148,6 +159,8 @@ class Notifier(object):
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
) )
self.replication_deferred = ObservableDeferred(defer.Deferred())
# This is not a very cheap test to perform, but it's only executed # This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at # when rendering the metrics page, which is likely once per minute at
# most when scraping it. # most when scraping it.
@ -177,8 +190,6 @@ class Notifier(object):
lambda: count(bool, self.appservice_to_user_streams.values()), lambda: count(bool, self.appservice_to_user_streams.values()),
) )
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
@ -192,13 +203,14 @@ class Notifier(object):
until all previous events have been persisted before notifying until all previous events have been persisted before notifying
the client streams. the client streams.
""" """
yield run_on_reactor() with PreserveLoggingContext():
self.pending_new_room_events.append(( self.pending_new_room_events.append((
room_stream_id, event, extra_users room_stream_id, event, extra_users
)) ))
self._notify_pending_new_room_events(max_room_stream_id) self._notify_pending_new_room_events(max_room_stream_id)
self.notify_replication()
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
event to be persisted. event to be persisted.
@ -244,15 +256,13 @@ class Notifier(object):
extra_streams=app_streams, extra_streams=app_streams,
) )
@defer.inlineCallbacks
@log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[], def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()): extra_streams=set()):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor() with PreserveLoggingContext():
user_streams = set() user_streams = set()
for user in users: for user in users:
@ -270,9 +280,17 @@ class Notifier(object):
except: except:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
self.notify_replication()
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
with PreserveLoggingContext():
self.notify_replication()
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken.START):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
@ -301,7 +319,7 @@ class Notifier(object):
def timed_out(): def timed_out():
if listener: if listener:
listener.deferred.cancel() listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out) timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token prev_token = from_token
while not result: while not result:
@ -318,6 +336,7 @@ class Notifier(object):
# that we don't miss any current_token updates. # that we don't miss any current_token updates.
prev_token = current_token prev_token = current_token
listener = user_stream.new_listener(prev_token) listener = user_stream.new_listener(prev_token)
with PreserveLoggingContext():
yield listener.deferred yield listener.deferred
except defer.CancelledError: except defer.CancelledError:
break break
@ -356,7 +375,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
defer.returnValue(None) defer.returnValue(EventStreamResult([], (from_token, from_token)))
events = [] events = []
end_token = from_token end_token = from_token
@ -369,6 +388,7 @@ class Notifier(object):
continue continue
if only_keys and name not in only_keys: if only_keys and name not in only_keys:
continue continue
new_events, new_key = yield source.get_new_events( new_events, new_key = yield source.get_new_events(
user=user, user=user,
from_key=getattr(from_token, keyname), from_key=getattr(from_token, keyname),
@ -388,10 +408,7 @@ class Notifier(object):
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
if events: defer.returnValue(EventStreamResult(events, (from_token, end_token)))
defer.returnValue((events, (from_token, end_token)))
else:
defer.returnValue(None)
user_id_for_stream = user.to_string() user_id_for_stream = user.to_string()
if is_peeking: if is_peeking:
@ -415,9 +432,6 @@ class Notifier(object):
from_token=from_token, from_token=from_token,
) )
if result is None:
result = ([], (from_token, from_token))
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -477,3 +491,45 @@ class Notifier(object):
room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream) room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id) new_user_stream.rooms.add(room_id)
def notify_replication(self):
"""Notify the any replication listeners that there's a new event"""
with PreserveLoggingContext():
deferred = self.replication_deferred
self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None)
@defer.inlineCallbacks
def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen.
:param callback:
Gets called whenever an event happens. If this returns a truthy
value then ``wait_for_replication`` returns, otherwise it waits
for another event.
:param int timeout:
How many milliseconds to wait for callback return a truthy value.
:returns:
A deferred that resolves with the value returned by the callback.
"""
listener = _NotificationListener(None)
def timed_out():
listener.deferred.cancel()
timer = self.clock.call_later(timeout / 1000., timed_out)
while True:
listener.deferred = self.replication_deferred.observe()
result = yield callback()
if result:
break
try:
with PreserveLoggingContext():
yield listener.deferred
except defer.CancelledError:
break
self.clock.cancel_call_later(timer, ignore_errs=True)
defer.returnValue(result)

View File

@ -17,10 +17,11 @@ from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.api.constants import Membership from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
import synapse.util.async import synapse.util.async
import push_rule_evaluator as push_rule_evaluator from .push_rule_evaluator import evaluator_for_user_id
import logging import logging
import random import random
@ -28,6 +29,16 @@ import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_NEXT_ID = 1
def _get_next_id():
global _NEXT_ID
_id = _NEXT_ID
_NEXT_ID += 1
return _id
# Pushers could now be moved to pull out of the event_push_actions table instead # Pushers could now be moved to pull out of the event_push_actions table instead
# of listening on the event stream: this would avoid them having to run the # of listening on the event stream: this would avoid them having to run the
# rules again. # rules again.
@ -36,14 +47,13 @@ class Pusher(object):
MAX_BACKOFF = 60 * 60 * 1000 MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000
def __init__(self, _hs, profile_tag, user_id, app_id, def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
self.hs = _hs self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.profile_tag = profile_tag
self.user_id = user_id self.user_id = user_id
self.app_id = app_id self.app_id = app_id
self.app_display_name = app_display_name self.app_display_name = app_display_name
@ -58,6 +68,8 @@ class Pusher(object):
self.alive = True self.alive = True
self.badge = None self.badge = None
self.name = "Pusher-%d" % (_get_next_id(),)
# The last value of last_active_time that we saw # The last value of last_active_time that we saw
self.last_last_active_time = 0 self.last_last_active_time = 0
self.has_unread = True self.has_unread = True
@ -87,6 +99,7 @@ class Pusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
with LoggingContext(self.name):
if not self.last_token: if not self.last_token:
# First-time setup: get a token to start from (we can't # First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now' # just start from no token, ie. 'now'
@ -97,17 +110,24 @@ class Pusher(object):
self.user_id, config, timeout=0, affect_presence=False self.user_id, config, timeout=0, affect_presence=False
) )
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( yield self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_id, self.last_token self.app_id, self.pushkey, self.user_id, self.last_token
) )
logger.info("Pusher %s for user %s starting from token %s", logger.info("New pusher %s for user %s starting from token %s",
self.pushkey, self.user_id, self.last_token) self.pushkey, self.user_id, self.last_token)
else:
logger.info(
"Old pusher %s for user %s starting",
self.pushkey, self.user_id,
)
wait = 0 wait = 0
while self.alive: while self.alive:
try: try:
if wait > 0: if wait > 0:
yield synapse.util.async.sleep(wait) yield synapse.util.async.sleep(wait)
with Measure(self.clock, "push"):
yield self.get_and_dispatch() yield self.get_and_dispatch()
wait = 0 wait = 0
except: except:
@ -165,8 +185,8 @@ class Pusher(object):
processed = False processed = False
rule_evaluator = yield \ rule_evaluator = yield \
push_rule_evaluator.evaluator_for_user_id_and_profile_tag( evaluator_for_user_id(
self.user_id, self.profile_tag, single_event['room_id'], self.store self.user_id, single_event['room_id'], self.store
) )
actions = yield rule_evaluator.actions_for_event(single_event) actions = yield rule_evaluator.actions_for_event(single_event)
@ -296,22 +316,19 @@ class Pusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_badge_count(self): def _get_badge_count(self):
room_list = yield self.store.get_rooms_for_user_where_membership_is( invites, joins = yield defer.gatherResults([
user_id=self.user_id, self.store.get_invites_for_user(self.user_id),
membership_list=(Membership.INVITE, Membership.JOIN) self.store.get_rooms_for_user(self.user_id),
) ], consumeErrors=True)
my_receipts_by_room = yield self.store.get_receipts_for_user( my_receipts_by_room = yield self.store.get_receipts_for_user(
self.user_id, self.user_id,
"m.read", "m.read",
) )
badge = 0 badge = len(invites)
for r in room_list: for r in joins:
if r.membership == Membership.INVITE:
badge += 1
else:
if r.room_id in my_receipts_by_room: if r.room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[r.room_id] last_unread_event_id = my_receipts_by_room[r.room_id]
@ -320,7 +337,7 @@ class Pusher(object):
r.room_id, self.user_id, last_unread_event_id r.room_id, self.user_id, last_unread_event_id
) )
) )
badge += len(notifs) badge += notifs["notify_count"]
defer.returnValue(badge) defer.returnValue(badge)

View File

@ -15,12 +15,10 @@
from twisted.internet import defer from twisted.internet import defer
import bulk_push_rule_evaluator from .bulk_push_rule_evaluator import evaluator_for_room_id
import logging import logging
from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,21 +34,15 @@ class ActionGenerator:
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, handler): def handle_push_actions_for_event(self, event, context, handler):
if event.type == EventTypes.Redaction and event.redacts is not None: bulk_evaluator = yield evaluator_for_room_id(
yield self.store.remove_push_actions_for_event_id(
event.room_id, event.redacts
)
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
event.room_id, self.hs, self.store event.room_id, self.hs, self.store
) )
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler) actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, handler, context.current_state
yield self.store.set_push_actions_for_event_and_users(
event,
[
(uid, None, actions) for uid, actions in actions_by_user.items()
]
) )
context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items()
]

View File

@ -13,46 +13,67 @@
# limitations under the License. # limitations under the License.
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
import copy
def list_with_base_rules(rawrules): def list_with_base_rules(rawrules):
"""Combine the list of rules set by the user with the default push rules
:param list rawrules: The rules the user has modified or set.
:returns: A new list with the rules set by the user combined with the
defaults.
"""
ruleslist = [] ruleslist = []
# Grab the base rules that the user has modified.
# The modified base rules have a priority_class of -1.
modified_base_rules = {
r['rule_id']: r for r in rawrules if r['priority_class'] < 0
}
# Remove the modified base rules from the list, They'll be added back
# in the default postions in the list.
rawrules = [r for r in rawrules if r['priority_class'] >= 0]
# shove the server default rules for each kind onto the end of each # shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
)) ))
for r in rawrules: for r in rawrules:
if r['priority_class'] < current_prio_class: if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
ruleslist.append(r) ruleslist.append(r)
while current_prio_class > 0: while current_prio_class > 0:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
return ruleslist return ruleslist
def make_base_append_rules(kind): def make_base_append_rules(kind, modified_base_rules):
rules = [] rules = []
if kind == 'override': if kind == 'override':
@ -62,15 +83,31 @@ def make_base_append_rules(kind):
elif kind == 'content': elif kind == 'content':
rules = BASE_APPEND_CONTENT_RULES rules = BASE_APPEND_CONTENT_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
modified = modified_base_rules.get(r['rule_id'])
if modified:
r['actions'] = modified['actions']
return rules return rules
def make_base_prepend_rules(kind): def make_base_prepend_rules(kind, modified_base_rules):
rules = [] rules = []
if kind == 'override': if kind == 'override':
rules = BASE_PREPEND_OVERRIDE_RULES rules = BASE_PREPEND_OVERRIDE_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
modified = modified_base_rules.get(r['rule_id'])
if modified:
r['actions'] = modified['actions']
return rules return rules
@ -173,6 +210,12 @@ BASE_APPEND_UNDERRIDE_RULES = [
'kind': 'room_member_count', 'kind': 'room_member_count',
'is': '2', 'is': '2',
'_id': 'member_count', '_id': 'member_count',
},
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.message',
'_id': '_message',
} }
], ],
'actions': [ 'actions': [
@ -257,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [
] ]
BASE_RULE_IDS = set()
for r in BASE_APPEND_CONTENT_RULES: for r in BASE_APPEND_CONTENT_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['content'] r['priority_class'] = PRIORITY_CLASS_MAP['content']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_PREPEND_OVERRIDE_RULES: for r in BASE_PREPEND_OVERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_APPEND_OVRRIDE_RULES: for r in BASE_APPEND_OVRRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_APPEND_UNDERRIDE_RULES: for r in BASE_APPEND_UNDERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['underride'] r['priority_class'] = PRIORITY_CLASS_MAP['underride']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])

View File

@ -18,8 +18,8 @@ import ujson as json
from twisted.internet import defer from twisted.internet import defer
import baserules from .baserules import list_with_base_rules
from push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -39,7 +39,7 @@ def _get_rules(room_id, user_ids, store):
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = { rules_by_user = {
uid: baserules.list_with_base_rules([ uid: list_with_base_rules([
decode_rule_json(rule_list) decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, []) for rule_list in rules_by_user.get(uid, [])
]) ])
@ -98,25 +98,21 @@ class BulkPushRuleEvaluator:
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, handler): def action_for_event_by_user(self, event, handler, current_state):
actions_by_user = {} actions_by_user = {}
users_dict = yield self.store.are_guests(self.rules_by_user.keys()) users_dict = yield self.store.are_guests(self.rules_by_user.keys())
filtered_by_user = yield handler._filter_events_for_clients( filtered_by_user = yield handler.filter_events_for_clients(
users_dict.items(), [event] users_dict.items(), [event], {event.event_id: current_state}
) )
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room)) evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {} condition_cache = {}
member_state = yield self.store.get_state_for_event(
event.event_id,
)
display_names = {} display_names = {}
for ev in member_state.values(): for ev in current_state.values():
nm = ev.content.get("displayname", None) nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member: if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm display_names[ev.state_key] = nm
@ -156,7 +152,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
elif res is True: elif res is True:
continue continue
res = evaluator.matches(cond, uid, display_name, None) res = evaluator.matches(cond, uid, display_name)
if _id: if _id:
cache[_id] = bool(res) cache[_id] = bool(res)

View File

@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push.baserules import list_with_base_rules
from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
import copy
import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
return rules
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
def _priority_class_to_template_name(pc):
return PRIORITY_CLASS_INVERSE_MAP[pc]

View File

@ -23,12 +23,11 @@ logger = logging.getLogger(__name__)
class HttpPusher(Pusher): class HttpPusher(Pusher):
def __init__(self, _hs, profile_tag, user_id, app_id, def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__( super(HttpPusher, self).__init__(
_hs, _hs,
profile_tag,
user_id, user_id,
app_id, app_id,
app_display_name, app_display_name,

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import baserules from .baserules import list_with_base_rules
import logging import logging
import simplejson as json import simplejson as json
@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): def evaluator_for_user_id(user_id, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id) rawrules = yield store.get_push_rules_for_user(user_id)
enabled_map = yield store.get_push_rules_enabled_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state( our_member_event = yield store.get_current_state(
@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
) )
defer.returnValue(PushRuleEvaluator( defer.returnValue(PushRuleEvaluator(
user_id, profile_tag, rawrules, enabled_map, user_id, rawrules, enabled_map,
room_id, our_member_event, store room_id, our_member_event, store
)) ))
@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count):
class PushRuleEvaluator: class PushRuleEvaluator:
DEFAULT_ACTIONS = [] DEFAULT_ACTIONS = []
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, def __init__(self, user_id, raw_rules, enabled_map, room_id,
our_member_event, store): our_member_event, store):
self.user_id = user_id self.user_id = user_id
self.profile_tag = profile_tag
self.room_id = room_id self.room_id = room_id
self.our_member_event = our_member_event self.our_member_event = our_member_event
self.store = store self.store = store
@ -92,7 +91,7 @@ class PushRuleEvaluator:
rule['actions'] = json.loads(raw_rule['actions']) rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule) rules.append(rule)
self.rules = baserules.list_with_base_rules(rules) self.rules = list_with_base_rules(rules)
self.enabled_map = enabled_map self.enabled_map = enabled_map
@ -152,7 +151,7 @@ class PushRuleEvaluator:
matches = True matches = True
for c in conditions: for c in conditions:
matches = evaluator.matches( matches = evaluator.matches(
c, self.user_id, my_display_name, self.profile_tag c, self.user_id, my_display_name
) )
if not matches: if not matches:
break break
@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object):
# Maps strings of e.g. 'content.body' -> event["content"]["body"] # Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
def matches(self, condition, user_id, display_name, profile_tag): def matches(self, condition, user_id, display_name):
if condition['kind'] == 'event_match': if condition['kind'] == 'event_match':
return self._event_match(condition, user_id) return self._event_match(condition, user_id)
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name': elif condition['kind'] == 'contains_display_name':
return self._contains_display_name(display_name) return self._contains_display_name(display_name)
elif condition['kind'] == 'room_member_count': elif condition['kind'] == 'room_member_count':
@ -304,7 +299,7 @@ def _flatten_dict(d, prefix=[], result={}):
if isinstance(value, basestring): if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower() result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"): elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result) _flatten_dict(value, prefix=(prefix + [key]), result=result)
return result return result

View File

@ -16,8 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from httppusher import HttpPusher from .httppusher import HttpPusher
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn
import logging import logging
@ -28,6 +29,7 @@ class PusherPool:
def __init__(self, _hs): def __init__(self, _hs):
self.hs = _hs self.hs = _hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pushers = {} self.pushers = {}
self.last_pusher_started = -1 self.last_pusher_started = -1
@ -37,8 +39,11 @@ class PusherPool:
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data,
profile_tag=""):
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
@ -46,23 +51,31 @@ class PusherPool:
self._create_pusher({ self._create_pusher({
"user_name": user_id, "user_name": user_id,
"kind": kind, "kind": kind,
"profile_tag": profile_tag,
"app_id": app_id, "app_id": app_id,
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
"pushkey": pushkey, "pushkey": pushkey,
"ts": self.hs.get_clock().time_msec(), "ts": time_now_msec,
"lang": lang, "lang": lang,
"data": data, "data": data,
"last_token": None, "last_token": None,
"last_success": None, "last_success": None,
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self.store.add_pusher(
user_id, access_token, profile_tag, kind, app_id, user_id=user_id,
app_display_name, device_display_name, access_token=access_token,
pushkey, lang, data kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=time_now_msec,
lang=lang,
data=data,
profile_tag=profile_tag,
) )
yield self._refresh_pusher(app_id, pushkey, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
@ -76,47 +89,27 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name'] app_id, pushkey, p['user_name']
) )
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id): def remove_pushers_by_user(self, user_id, except_token_ids=[]):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s", "Removing all pushers for user %s except access tokens ids %r",
user_id, user_id, except_token_ids
) )
for p in all: for p in all:
if p['user_name'] == user_id: if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']
) )
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_id=user_id,
access_token=access_token,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
data=data,
)
self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
return HttpPusher( return HttpPusher(
self.hs, self.hs,
profile_tag=pusherdict['profile_tag'],
user_id=pusherdict['user_name'], user_id=pusherdict['user_name'],
app_id=pusherdict['app_id'], app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],
@ -166,7 +159,7 @@ class PusherPool:
if fullid in self.pushers: if fullid in self.pushers:
self.pushers[fullid].stop() self.pushers[fullid].stop()
self.pushers[fullid] = p self.pushers[fullid] = p
p.start() preserve_fn(p.start)()
logger.info("Started pushers") logger.info("Started pushers")

View File

@ -19,10 +19,10 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"], "frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"], "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"], "Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -34,12 +34,12 @@ REQUIREMENTS = {
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
"blist": ["blist"], "blist": ["blist"],
"pysaml2": ["saml2"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
"matrix_angular_sdk>=0.6.6": ["syweb>=0.6.6"], "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],
} }
} }

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,367 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
import ujson as json
import collections
import logging
logger = logging.getLogger(__name__)
REPLICATION_PREFIX = "/_synapse/replication"
STREAM_NAMES = (
("events",),
("presence",),
("typing",),
("receipts",),
("user_account_data", "room_account_data", "tag_account_data",),
("backfill",),
("push_rules",),
("pushers",),
)
class ReplicationResource(Resource):
"""
HTTP endpoint for extracting data from synapse.
The streams of data returned by the endpoint are controlled by the
parameters given to the API. To return a given stream pass a query
parameter with a position in the stream to return data from or the
special value "-1" to return data from the start of the stream.
If there is no data for any of the supplied streams after the given
position then the request will block until there is data for one
of the streams. This allows clients to long-poll this API.
The possible streams are:
* "streams": A special stream returing the positions of other streams.
* "events": The new events seen on the server.
* "presence": Presence updates.
* "typing": Typing updates.
* "receipts": Receipt updates.
* "user_account_data": Top-level per user account data.
* "room_account_data: Per room per user account data.
* "tag_account_data": Per room per user tags.
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers.
The API takes two additional query parameters:
* "timeout": How long to wait before returning an empty response.
* "limit": The maximum number of rows to return for the selected streams.
The response is a JSON object with keys for each stream with updates. Under
each key is a JSON object with:
* "position": The current position of the stream.
* "field_names": The names of the fields in each row.
* "rows": The updates as an array of arrays.
There are a number of ways this API could be used:
1) To replicate the contents of the backing database to another database.
2) To be notified when the contents of a shared backing database changes.
3) To "tail" the activity happening on a server for debugging.
In the first case the client would track all of the streams and store it's
own copy of the data.
In the second case the client might theoretically just be able to follow
the "streams" stream to track where the other streams are. However in
practise it will probably need to get the contents of the streams in
order to expire the any in-memory caches. Whether it gets the contents
of the streams from this replication API or directly from the backing
store is a matter of taste.
In the third case the client would use the "streams" stream to find what
streams are available and their current positions. Then it can start
long-polling this replication API for new data on those streams.
"""
isLeaf = True
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.store = hs.get_datastore()
self.sources = hs.get_event_sources()
self.presence_handler = hs.get_handlers().presence_handler
self.typing_handler = hs.get_handlers().typing_notification_handler
self.notifier = hs.notifier
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@defer.inlineCallbacks
def current_replication_token(self):
stream_token = yield self.sources.get_current_token()
backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
int(stream_token.presence_key),
int(stream_token.typing_key),
int(stream_token.receipt_key),
int(stream_token.account_data_key),
backfill_token,
push_rules_token,
pushers_token,
))
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
limit = parse_integer(request, "limit", 100)
timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json")
writer = _Writer(request)
@defer.inlineCallbacks
def replicate():
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit)
yield self.events(writer, current_token, limit)
yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
yield self.pushers(writer, current_token, limit)
self.streams(writer, current_token)
logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.total)
yield self.notifier.wait_for_replication(replicate, timeout)
writer.finish()
def streams(self, writer, current_token):
request_token = parse_string(writer.request, "streams")
streams = []
if request_token is not None:
if request_token == "-1":
for names, position in zip(STREAM_NAMES, current_token):
streams.extend((name, position) for name in names)
else:
items = zip(
STREAM_NAMES,
current_token,
_ReplicationToken(request_token)
)
for names, current_id, last_id in items:
if last_id < current_id:
streams.extend((name, current_id) for name in names)
if streams:
writer.write_header_and_rows(
"streams", streams, ("name", "position"),
position=str(current_token)
)
@defer.inlineCallbacks
def events(self, writer, current_token, limit):
request_events = parse_integer(writer.request, "events")
request_backfill = parse_integer(writer.request, "backfill")
if request_events is not None or request_backfill is not None:
if request_events is None:
request_events = current_token.events
if request_backfill is None:
request_backfill = current_token.backfill
events_rows, backfill_rows = yield self.store.get_all_new_events(
request_backfill, request_events,
current_token.backfill, current_token.events,
limit
)
writer.write_header_and_rows(
"events", events_rows, ("position", "internal", "json")
)
writer.write_header_and_rows(
"backfill", backfill_rows, ("position", "internal", "json")
)
@defer.inlineCallbacks
def presence(self, writer, current_token):
current_position = current_token.presence
request_presence = parse_integer(writer.request, "presence")
if request_presence is not None:
presence_rows = yield self.presence_handler.get_all_presence_updates(
request_presence, current_position
)
writer.write_header_and_rows("presence", presence_rows, (
"position", "user_id", "state", "last_active_ts",
"last_federation_update_ts", "last_user_sync_ts",
"status_msg", "currently_active",
))
@defer.inlineCallbacks
def typing(self, writer, current_token):
current_position = current_token.presence
request_typing = parse_integer(writer.request, "typing")
if request_typing is not None:
typing_rows = yield self.typing_handler.get_all_typing_updates(
request_typing, current_position
)
writer.write_header_and_rows("typing", typing_rows, (
"position", "room_id", "typing"
))
@defer.inlineCallbacks
def receipts(self, writer, current_token, limit):
current_position = current_token.receipts
request_receipts = parse_integer(writer.request, "receipts")
if request_receipts is not None:
receipts_rows = yield self.store.get_all_updated_receipts(
request_receipts, current_position, limit
)
writer.write_header_and_rows("receipts", receipts_rows, (
"position", "room_id", "receipt_type", "user_id", "event_id", "data"
))
@defer.inlineCallbacks
def account_data(self, writer, current_token, limit):
current_position = current_token.account_data
user_account_data = parse_integer(writer.request, "user_account_data")
room_account_data = parse_integer(writer.request, "room_account_data")
tag_account_data = parse_integer(writer.request, "tag_account_data")
if user_account_data is not None or room_account_data is not None:
if user_account_data is None:
user_account_data = current_position
if room_account_data is None:
room_account_data = current_position
user_rows, room_rows = yield self.store.get_all_updated_account_data(
user_account_data, room_account_data, current_position, limit
)
writer.write_header_and_rows("user_account_data", user_rows, (
"position", "user_id", "type", "content"
))
writer.write_header_and_rows("room_account_data", room_rows, (
"position", "user_id", "room_id", "type", "content"
))
if tag_account_data is not None:
tag_rows = yield self.store.get_all_updated_tags(
tag_account_data, current_position, limit
)
writer.write_header_and_rows("tag_account_data", tag_rows, (
"position", "user_id", "room_id", "tags"
))
@defer.inlineCallbacks
def push_rules(self, writer, current_token, limit):
current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
push_rules, current_position, limit
)
writer.write_header_and_rows("push_rules", rows, (
"position", "event_stream_ordering", "user_id", "rule_id", "op",
"priority_class", "priority", "conditions", "actions"
))
@defer.inlineCallbacks
def pushers(self, writer, current_token, limit):
current_position = current_token.pushers
pushers = parse_integer(writer.request, "pushers")
if pushers is not None:
updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit
)
writer.write_header_and_rows("pushers", updated, (
"position", "user_id", "access_token", "profile_tag", "kind",
"app_id", "app_display_name", "device_display_name", "pushkey",
"ts", "lang", "data"
))
writer.write_header_and_rows("deleted", deleted, (
"position", "user_id", "app_id", "pushkey"
))
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
def __init__(self, request):
self.streams = {}
self.request = request
self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None):
if not rows:
return
if position is None:
position = rows[-1][0]
self.streams[name] = {
"position": str(position),
"field_names": fields,
"rows": rows,
}
self.total += len(rows)
def finish(self):
self.request.write(json.dumps(self.streams, ensure_ascii=False))
finish_request(self.request)
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers"
))):
__slots__ = []
def __new__(cls, *args):
if len(args) == 1:
streams = [int(value) for value in args[0].split("_")]
if len(streams) < len(cls._fields):
streams.extend([0] * (len(cls._fields) - len(streams)))
return cls(*streams)
else:
return super(_ReplicationToken, cls).__new__(cls, *args)
def __str__(self):
return "_".join(str(value) for value in self)

View File

@ -30,6 +30,7 @@ from synapse.rest.client.v1 import (
push_rule, push_rule,
register as v1_register, register as v1_register,
login as v1_login, login as v1_login,
logout,
) )
from synapse.rest.client.v2_alpha import ( from synapse.rest.client.v2_alpha import (
@ -72,6 +73,7 @@ class ClientRestResource(JsonResource):
admin.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource)
logout.register_servlets(hs, client_resource)
# "v2" # "v2"
sync.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import logging import logging

View File

@ -18,9 +18,10 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import RoomAlias from synapse.types import RoomAlias
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
@ -45,7 +46,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
content = _parse_json(request) content = parse_json_object_from_request(request)
if "room_id" not in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, "Missing room_id key",
errcode=Codes.BAD_JSON) errcode=Codes.BAD_JSON)
@ -75,7 +76,11 @@ class ClientDirectoryServer(ClientV1RestServlet):
yield dir_handler.create_association( yield dir_handler.create_association(
user_id, room_alias, room_id, servers user_id, room_alias, room_id, servers
) )
yield dir_handler.send_room_alias_update_event(user_id, room_id) yield dir_handler.send_room_alias_update_event(
requester,
user_id,
room_id
)
except SynapseError as e: except SynapseError as e:
raise e raise e
except: except:
@ -118,15 +123,13 @@ class ClientDirectoryServer(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
raise AuthError(403, "You need to be a server admin")
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_association( yield dir_handler.delete_association(
user.to_string(), room_alias requester, user.to_string(), room_alias
) )
logger.info( logger.info(
"User %s deleted alias %s", "User %s deleted alias %s",
user.to_string(), user.to_string(),
@ -134,14 +137,3 @@ class ClientDirectoryServer(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing # TODO: Needs unit testing

View File

@ -17,7 +17,10 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, LoginError, Codes from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_patterns from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import urllib import urllib
@ -77,7 +80,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
login_submission = _parse_json(request) login_submission = parse_json_object_from_request(request)
try: try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if login_submission["type"] == LoginRestServlet.PASS_TYPE:
if not self.password_enabled: if not self.password_enabled:
@ -89,7 +92,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE): LoginRestServlet.SAML2_TYPE):
relay_state = "" relay_state = ""
if "relay_state" in login_submission: if "relay_state" in login_submission:
relay_state = "&RelayState="+urllib.quote( relay_state = "&RelayState=" + urllib.quote(
login_submission["relay_state"]) login_submission["relay_state"])
result = { result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
@ -250,7 +253,7 @@ class SAML2RestServlet(ClientV1RestServlet):
SP = Saml2Client(conf) SP = Saml2Client(conf)
saml2_auth = SP.parse_authn_request_response( saml2_auth = SP.parse_authn_request_response(
request.args['SAMLResponse'][0], BINDING_HTTP_POST) request.args['SAMLResponse'][0], BINDING_HTTP_POST)
except Exception, e: # Not authenticated except Exception as e: # Not authenticated
logger.exception(e) logger.exception(e)
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
username = saml2_auth.name_id.text username = saml2_auth.name_id.text
@ -263,7 +266,7 @@ class SAML2RestServlet(ClientV1RestServlet):
'?status=authenticated&access_token=' + '?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' + token + '&user_id=' + user_id + '&ava=' +
urllib.quote(json.dumps(saml2_auth.ava))) urllib.quote(json.dumps(saml2_auth.ava)))
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
defer.returnValue((200, {"status": "authenticated", defer.returnValue((200, {"status": "authenticated",
"user_id": user_id, "token": token, "user_id": user_id, "token": token,
@ -272,7 +275,7 @@ class SAML2RestServlet(ClientV1RestServlet):
request.redirect(urllib.unquote( request.redirect(urllib.unquote(
request.args['RelayState'][0]) + request.args['RelayState'][0]) +
'?status=not_authenticated') '?status=not_authenticated')
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
@ -309,7 +312,7 @@ class CasRedirectServlet(ClientV1RestServlet):
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
}) })
request.redirect("%s?%s" % (self.cas_server_url, service_param)) request.redirect("%s?%s" % (self.cas_server_url, service_param))
request.finish() finish_request(request)
class CasTicketServlet(ClientV1RestServlet): class CasTicketServlet(ClientV1RestServlet):
@ -362,7 +365,7 @@ class CasTicketServlet(ClientV1RestServlet):
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token) login_token)
request.redirect(redirect_url) request.redirect(redirect_url)
request.finish() finish_request(request)
def add_login_token_to_redirect_url(self, url, token): def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url)) url_parts = list(urlparse.urlparse(url))
@ -398,16 +401,6 @@ class CasTicketServlet(ClientV1RestServlet):
return (user, attributes) return (user, attributes)
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled: if hs.config.saml2_enabled:

View File

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes
from .base import ClientV1RestServlet, client_path_patterns
import logging
logger = logging.getLogger(__name__)
class LogoutRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/logout$")
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
try:
access_token = request.args["access_token"][0]
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
yield self.store.delete_access_token(access_token)
defer.returnValue((200, {}))
class LogoutAllRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/logout/all$")
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.auth = hs.get_auth()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
yield self.store.user_delete_access_tokens(user_id)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
LogoutRestServlet(hs).register(http_server)
LogoutAllRestServlet(hs).register(http_server)

View File

@ -17,11 +17,11 @@
""" """
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID from synapse.types import UserID
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( if requester.user != user:
target_user=user, auth_user=requester.user) allowed = yield self.handlers.presence_handler.is_visible(
observed_user=user, observer_user=requester.user,
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.handlers.presence_handler.get_state(target_user=user)
defer.returnValue((200, state)) defer.returnValue((200, state))
@ -45,10 +52,14 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} if requester.user != user:
try: raise AuthError(403, "Can only set your own presence state")
content = json.loads(request.content.read())
state = {}
content = parse_json_object_from_request(request)
try:
state["presence"] = content.pop("presence") state["presence"] = content.pop("presence")
if "status_msg" in content: if "status_msg" in content:
@ -63,8 +74,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except: except:
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state( yield self.handlers.presence_handler.set_state(user, state)
target_user=user, auth_user=requester.user, state=state)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -87,11 +97,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Cannot get another user's presence list") raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list( presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True) observer_user=user, accepted=True
)
for p in presence:
observed_user = p.pop("observed_user")
p["user_id"] = observed_user.to_string()
defer.returnValue((200, presence)) defer.returnValue((200, presence))
@ -107,11 +114,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError( raise SynapseError(
400, "Cannot modify another user's presence list") 400, "Cannot modify another user's presence list")
try: content = parse_json_object_from_request(request)
content = json.loads(request.content.read())
except:
logger.exception("JSON parse error")
raise SynapseError(400, "Unable to parse content")
if "invite" in content: if "invite" in content:
for u in content["invite"]: for u in content["invite"]:

View File

@ -18,8 +18,7 @@ from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from synapse.types import UserID from synapse.types import UserID
from synapse.http.servlet import parse_json_object_from_request
import simplejson as json
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@ -33,21 +32,26 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
user, user,
) )
defer.returnValue((200, {"displayname": displayname})) ret = {}
if displayname is not None:
ret["displayname"] = displayname
defer.returnValue((200, ret))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
content = parse_json_object_from_request(request)
try: try:
content = json.loads(request.content.read())
new_name = content["displayname"] new_name = content["displayname"]
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.handlers.profile_handler.set_displayname(
user, requester.user, new_name) user, requester, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -66,21 +70,25 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
user, user,
) )
defer.returnValue((200, {"avatar_url": avatar_url})) ret = {}
if avatar_url is not None:
ret["avatar_url"] = avatar_url
defer.returnValue((200, ret))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
content = parse_json_object_from_request(request)
try: try:
content = json.loads(request.content.read())
new_name = content["avatar_url"] new_name = content["avatar_url"]
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.handlers.profile_handler.set_avatar_url(
user, requester.user, new_name) user, requester, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -102,10 +110,13 @@ class ProfileRestServlet(ClientV1RestServlet):
user, user,
) )
defer.returnValue((200, { ret = {}
"displayname": displayname, if displayname is not None:
"avatar_url": avatar_url ret["displayname"] = displayname
})) if avatar_url is not None:
ret["avatar_url"] = avatar_url
defer.returnValue((200, ret))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View File

@ -16,19 +16,16 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
) )
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from synapse.storage.push_rule import ( from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException InconsistentRuleException, RuleNotFoundException
) )
import synapse.push.baserules as baserules from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import ( from synapse.push.baserules import BASE_RULE_IDS
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
) from synapse.http.servlet import parse_json_value_from_request
import copy
import simplejson as json
class PushRuleRestServlet(ClientV1RestServlet): class PushRuleRestServlet(ClientV1RestServlet):
@ -36,6 +33,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request): def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
@ -49,32 +51,39 @@ class PushRuleRestServlet(ClientV1RestServlet):
if '/' in spec['rule_id'] or '\\' in spec['rule_id']: if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
content = _parse_json(request) content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
if 'attr' in spec: if 'attr' in spec:
self.set_rule_attr(requester.user.to_string(), spec, content) yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
if spec['rule_id'].startswith('.'):
# Rule ids starting with '.' are reserved for server default rules.
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
try: try:
(conditions, actions) = _rule_tuple_from_request_object( (conditions, actions) = _rule_tuple_from_request_object(
spec['template'], spec['template'],
spec['rule_id'], spec['rule_id'],
content, content,
device=spec['device'] if 'device' in spec else None
) )
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
before = request.args.get("before", None) before = request.args.get("before", None)
if before and len(before): if before:
before = before[0] before = _namespaced_rule_id(spec, before[0])
after = request.args.get("after", None) after = request.args.get("after", None)
if after and len(after): if after:
after = after[0] after = _namespaced_rule_id(spec, after[0])
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.store.add_push_rule(
user_id=requester.user.to_string(), user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
conditions=conditions, conditions=conditions,
@ -82,6 +91,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
before=before, before=before,
after=after after=after
) )
self.notify_user(user_id)
except InconsistentRuleException as e: except InconsistentRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
except RuleNotFoundException as e: except RuleNotFoundException as e:
@ -94,13 +104,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.hs.get_datastore().delete_push_rule( yield self.store.delete_push_rule(
requester.user.to_string(), namespaced_rule_id user_id, namespaced_rule_id
) )
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
@ -111,74 +123,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user( rawrules = yield self.store.get_push_rules_for_user(user_id)
user.to_string()
)
ruleslist = [] enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
r = _strip_device_condition(r)
if not profile_tag:
continue
if profile_tag not in rules['device']:
rules['device'][profile_tag] = {}
rules['device'][profile_tag] = (
_add_empty_priority_class_arrays(
rules['device'][profile_tag]
)
)
rulearray = rules['device'][profile_tag][template_name]
else:
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
path = request.postpath[1:] path = request.postpath[1:]
@ -194,30 +148,18 @@ class PushRuleRestServlet(ClientV1RestServlet):
path = path[1:] path = path[1:]
result = _filter_ruleset_with_path(rules['global'], path) result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result)) defer.returnValue((200, result))
elif path[0] == 'device':
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules['device']))
profile_tag = path[0]
path = path[1:]
if profile_tag not in rules['device']:
ret = {}
ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret))
ruleset = rules['device'][profile_tag]
result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result))
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event(
"push_rules_key", stream_id, users=[user_id]
)
def set_rule_attr(self, user_id, spec, val): def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled': if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val: if isinstance(val, dict) and "enabled" in val:
@ -228,16 +170,20 @@ class PushRuleRestServlet(ClientV1RestServlet):
# bools directly, so let's not break them. # bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean") raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled( return self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val user_id, namespaced_rule_id, val
) )
else: elif spec['attr'] == 'actions':
raise UnrecognizedRequestError() actions = val.get('actions')
_check_actions(actions)
def get_rule_attr(self, user_id, namespaced_rule_id, attr): namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
if attr == 'enabled': rule_id = spec['rule_id']
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( is_default_rule = rule_id.startswith(".")
user_id, namespaced_rule_id if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
) )
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -251,16 +197,9 @@ def _rule_spec_from_path(path):
scope = path[1] scope = path[1]
path = path[2:] path = path[2:]
if scope not in ['global', 'device']: if scope != 'global':
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0: if len(path) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -277,8 +216,6 @@ def _rule_spec_from_path(path):
'template': template, 'template': template,
'rule_id': rule_id 'rule_id': rule_id
} }
if device:
spec['profile_tag'] = device
path = path[1:] path = path[1:]
@ -288,7 +225,7 @@ def _rule_spec_from_path(path):
return spec return spec
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
if rule_template in ['override', 'underride']: if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj: if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'") raise InvalidRuleException("Missing 'conditions'")
@ -321,16 +258,19 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
else: else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
if device:
conditions.append({
'kind': 'device',
'profile_tag': device
})
if 'actions' not in req_obj: if 'actions' not in req_obj:
raise InvalidRuleException("No actions found") raise InvalidRuleException("No actions found")
actions = req_obj['actions'] actions = req_obj['actions']
_check_actions(actions)
return conditions, actions
def _check_actions(actions):
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
for a in actions: for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']: if a in ['notify', 'dont_notify', 'coalesce']:
pass pass
@ -339,25 +279,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
else: else:
raise InvalidRuleException("Unrecognised action") raise InvalidRuleException("Unrecognised action")
return conditions, actions
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _profile_tag_from_conditions(conditions):
"""
Given a list of conditions, return the profile tag of the
device rule if there is one
"""
for c in conditions:
if c['kind'] == 'device':
return c['profile_tag']
return None
def _filter_ruleset_with_path(ruleset, path): def _filter_ruleset_with_path(ruleset, path):
if path == []: if path == []:
@ -392,89 +313,32 @@ def _filter_ruleset_with_path(ruleset, path):
attr = path[0] attr = path[0]
if attr in the_rule: if attr in the_rule:
return the_rule[attr] # Make sure we return a JSON object as the attribute may be a
# JSON value.
return {attr: the_rule[attr]}
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
def _priority_class_from_spec(spec): def _priority_class_from_spec(spec):
if spec['template'] not in PRIORITY_CLASS_MAP.keys(): if spec['template'] not in PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) raise InvalidRuleException("Unknown template: %s" % (spec['template']))
pc = PRIORITY_CLASS_MAP[spec['template']] pc = PRIORITY_CLASS_MAP[spec['template']]
if spec['scope'] == 'device':
pc += len(PRIORITY_CLASS_MAP)
return pc return pc
def _priority_class_to_template_name(pc):
if pc > PRIORITY_CLASS_MAP['override']:
# per-device
prio_class_index = pc - len(PRIORITY_CLASS_MAP)
return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else:
return PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _strip_device_condition(rule):
for i, c in enumerate(rule['conditions']):
if c['kind'] == 'device':
del rule['conditions'][i]
return rule
def _namespaced_rule_id_from_spec(spec): def _namespaced_rule_id_from_spec(spec):
if spec['scope'] == 'global': return _namespaced_rule_id(spec, spec['rule_id'])
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
def _rule_id_from_namespaced(in_rule_id): def _namespaced_rule_id(spec, rule_id):
return in_rule_id.split('/')[-1] return "global/%s/%s" % (spec['template'], rule_id)
class InvalidRuleException(Exception): class InvalidRuleException(Exception):
pass pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server) PushRuleRestServlet(hs).register(http_server)

View File

@ -17,9 +17,10 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,12 +29,16 @@ logger = logging.getLogger(__name__)
class PusherRestServlet(ClientV1RestServlet): class PusherRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/pushers/set$") PATTERNS = client_path_patterns("/pushers/set$")
def __init__(self, hs):
super(PusherRestServlet, self).__init__(hs)
self.notifier = hs.get_notifier()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
content = _parse_json(request) content = parse_json_object_from_request(request)
pusher_pool = self.hs.get_pusherpool() pusher_pool = self.hs.get_pusherpool()
@ -45,14 +50,14 @@ class PusherRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', reqd = ['kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data'] 'device_display_name', 'pushkey', 'lang', 'data']
missing = [] missing = []
for i in reqd: for i in reqd:
if i not in content: if i not in content:
missing.append(i) missing.append(i)
if len(missing): if len(missing):
raise SynapseError(400, "Missing parameters: "+','.join(missing), raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
@ -73,36 +78,26 @@ class PusherRestServlet(ClientV1RestServlet):
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_id=user.to_string(), user_id=user.to_string(),
access_token=requester.access_token_id, access_token=requester.access_token_id,
profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],
app_display_name=content['app_display_name'], app_display_name=content['app_display_name'],
device_display_name=content['device_display_name'], device_display_name=content['device_display_name'],
pushkey=content['pushkey'], pushkey=content['pushkey'],
lang=content['lang'], lang=content['lang'],
data=content['data'] data=content['data'],
profile_tag=content.get('profile_tag', ""),
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError(400, "Config Error: "+pce.message, raise SynapseError(400, "Config Error: " + pce.message,
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
self.notifier.on_new_replication_data()
defer.returnValue((200, {})) defer.returnValue((200, {}))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
PusherRestServlet(hs).register(http_server) PusherRestServlet(hs).register(http_server)

View File

@ -18,14 +18,14 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from hashlib import sha1 from hashlib import sha1
import hmac import hmac
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest compare_digest = hmac.compare_digest
else: else:
compare_digest = lambda a, b: a == b def compare_digest(a, b):
return a == b
class RegisterRestServlet(ClientV1RestServlet): class RegisterRestServlet(ClientV1RestServlet):
@ -58,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# } # }
# TODO: persistent storage # TODO: persistent storage
self.sessions = {} self.sessions = {}
self.disable_registration = hs.config.disable_registration self.enable_registration = hs.config.enable_registration
def on_GET(self, request): def on_GET(self, request):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -97,7 +98,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
register_json = _parse_json(request) register_json = parse_json_object_from_request(request)
session = (register_json["session"] session = (register_json["session"]
if "session" in register_json else None) if "session" in register_json else None)
@ -112,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet):
is_using_shared_secret = login_type == LoginType.SHARED_SECRET is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = ( can_register = (
not self.disable_registration self.enable_registration
or is_application_server or is_application_server
or is_using_shared_secret or is_using_shared_secret
) )
@ -354,15 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet):
) )
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@ -16,14 +16,14 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ """ This module contains REST servlets to do with rooms: /rooms/<paths> """
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request
import simplejson as json
import logging import logging
import urllib import urllib
@ -63,35 +63,18 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(
room_config,
requester.user,
None,
)
room_config.update(info)
defer.returnValue((200, info))
@defer.inlineCallbacks
def make_room(self, room_config, auth_user, room_id):
handler = self.handlers.room_creation_handler handler = self.handlers.room_creation_handler
info = yield handler.create_room( info = yield handler.create_room(
user_id=auth_user.to_string(), requester, self.get_room_config(request)
room_id=room_id,
config=room_config
) )
defer.returnValue(info)
defer.returnValue((200, info))
def get_room_config(self, request): def get_room_config(self, request):
try: user_supplied_config = parse_json_object_from_request(request)
user_supplied_config = json.loads(request.content.read())
if "visibility" not in user_supplied_config:
# default visibility # default visibility
user_supplied_config["visibility"] = "public" user_supplied_config.setdefault("visibility", "public")
return user_supplied_config return user_supplied_config
except (ValueError, TypeError):
raise SynapseError(400, "Body must be JSON.",
errcode=Codes.BAD_JSON)
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -149,7 +132,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
event_dict = { event_dict = {
"type": event_type, "type": event_type,
@ -162,11 +145,22 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event( event, context = yield msg_handler.create_event(
event_dict, token_id=requester.access_token_id, txn_id=txn_id, event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
) )
defer.returnValue((200, {})) if event_type == EventTypes.Member:
yield self.handlers.room_member_handler.send_membership_event(
requester,
event,
context,
)
else:
yield msg_handler.send_nonmember_event(requester, event, context)
defer.returnValue((200, {"event_id": event.event_id}))
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
@ -180,17 +174,17 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event( event = yield msg_handler.create_and_send_nonmember_event(
requester,
{ {
"type": event_type, "type": event_type,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -229,46 +223,37 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
allow_guest=True, allow_guest=True,
) )
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
# SynapseErrors.
identifier = None
is_room_alias = False
try: try:
identifier = RoomAlias.from_string(room_identifier) content = parse_json_object_from_request(request)
is_room_alias = True except:
except SynapseError: # Turns out we used to ignore the body entirely, and some clients
identifier = RoomID.from_string(room_identifier) # cheekily send invalid bodies.
content = {}
# TODO: Support for specifying the home server to join with? if RoomID.is_valid(room_identifier):
room_id = room_identifier
if is_room_alias: remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias( room_alias = RoomAlias.from_string(room_identifier)
requester.user, room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
identifier, room_id = room_id.to_string()
) else:
defer.returnValue((200, ret_dict)) raise SynapseError(400, "%s was not legal room ID or room alias" % (
else: # room id room_identifier,
msg_handler = self.handlers.message_handler ))
content = {"membership": Membership.JOIN}
if requester.is_guest: yield self.handlers.room_member_handler.update_membership(
content["kind"] = "guest" requester=requester,
yield msg_handler.create_and_send_event( target=requester.user,
{ room_id=room_id,
"type": EventTypes.Member, action="join",
"content": content,
"room_id": identifier.to_string(),
"sender": requester.user.to_string(),
"state_key": requester.user.to_string(),
},
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=requester.is_guest, remote_room_hosts=remote_room_hosts,
third_party_signed=content.get("third_party_signed", None),
) )
defer.returnValue((200, {"room_id": identifier.to_string()})) defer.returnValue((200, {"room_id": room_id}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
@ -316,18 +301,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
if event["type"] != EventTypes.Member: if event["type"] != EventTypes.Member:
continue continue
chunk.append(event) chunk.append(event)
# FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user,
auth_user=requester.user,
)
event["content"].update(presence_state)
except:
pass
defer.returnValue((200, { defer.returnValue((200, {
"chunk": chunk "chunk": chunk
@ -429,8 +402,6 @@ class RoomEventContext(ClientV1RestServlet):
serialize_event(event, time_now) for event in results["state"] serialize_event(event, time_now) for event in results["state"]
] ]
logger.info("Responding with %r", results)
defer.returnValue((200, results)) defer.returnValue((200, results))
@ -456,7 +427,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
}: }:
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
content = _parse_json(request) try:
content = parse_json_object_from_request(request)
except:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.handlers.room_member_handler.do_3pid_invite(
@ -465,7 +441,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content["medium"], content["medium"],
content["address"], content["address"],
content["id_server"], content["id_server"],
requester.access_token_id, requester,
txn_id txn_id
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -483,6 +459,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
action=membership_action, action=membership_action,
txn_id=txn_id, txn_id=txn_id,
third_party_signed=content.get("third_party_signed", None),
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -518,10 +495,11 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event( event = yield msg_handler.create_and_send_nonmember_event(
requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
"content": content, "content": content,
@ -529,7 +507,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
"redacts": event_id, "redacts": event_id,
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -555,6 +532,10 @@ class RoomTypingRestServlet(ClientV1RestServlet):
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
) )
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -562,10 +543,12 @@ class RoomTypingRestServlet(ClientV1RestServlet):
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))
content = _parse_json(request) content = parse_json_object_from_request(request)
typing_handler = self.handlers.typing_notification_handler typing_handler = self.handlers.typing_notification_handler
yield self.presence_handler.bump_presence_active_time(requester.user)
if content["typing"]: if content["typing"]:
yield typing_handler.started_typing( yield typing_handler.started_typing(
target_user=target_user, target_user=target_user,
@ -592,7 +575,7 @@ class SearchRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0] batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search( results = yield self.handlers.search_handler.search(
@ -604,17 +587,6 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_txn_path(servlet, regex_string, http_server, with_get=False): def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""Registers a transaction-based path. """Registers a transaction-based path.

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import hmac import hmac

View File

@ -17,11 +17,9 @@
""" """
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.errors import SynapseError
import re import re
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,23 +42,3 @@ def client_v2_patterns(path_regex, releases=(0,)):
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex)) patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns return patterns
def parse_request_allow_empty(request):
content = request.content.read()
if content is None or content == '':
return None
try:
return simplejson.loads(content)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")
def parse_json_dict_from_request(request):
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")

View File

@ -17,10 +17,10 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
import logging import logging
@ -41,9 +41,9 @@ class PasswordRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
authed, result, params = yield self.auth_handler.check_auth([ authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet):
new_password = params['new_password'] new_password = params['new_password']
yield self.auth_handler.set_password( yield self.auth_handler.set_password(
user_id, new_password user_id, new_password, requester
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -114,11 +114,12 @@ class ThreepidRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
if 'threePidCreds' not in body: threePidCreds = body.get('threePidCreds')
threePidCreds = body.get('three_pid_creds', threePidCreds)
if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()

View File

@ -15,15 +15,13 @@
from ._base import client_v2_patterns from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError
from twisted.internet import defer from twisted.internet import defer
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,17 +45,13 @@ class AccountDataServlet(RestServlet):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
try: body = parse_json_object_from_request(request)
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
max_id = yield self.store.add_account_data_for_user( max_id = yield self.store.add_account_data_for_user(
user_id, account_data_type, body user_id, account_data_type, body
) )
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
@ -86,20 +80,13 @@ class RoomAccountDataServlet(RestServlet):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
try: body = parse_json_object_from_request(request)
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
if not isinstance(body, dict):
raise ValueError("Expected a JSON object")
max_id = yield self.store.add_account_data_to_room( max_id = yield self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body
) )
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -130,7 +131,7 @@ class AuthRestServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes) request.write(html_bytes)
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
else: else:
raise SynapseError(404, "Unknown auth stage type") raise SynapseError(404, "Unknown auth stage type")
@ -176,7 +177,7 @@ class AuthRestServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes) request.write(html_bytes)
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
else: else:

View File

@ -16,12 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_patterns from ._base import client_v2_patterns
import simplejson as json
import logging import logging
@ -59,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
filter_id=filter_id, filter_id=filter_id,
) )
defer.returnValue((200, filter.filter_json)) defer.returnValue((200, filter.get_filter_json()))
except KeyError: except KeyError:
raise SynapseError(400, "No such filter") raise SynapseError(400, "No such filter")
@ -84,12 +83,7 @@ class CreateFilterRestServlet(RestServlet):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users") raise SynapseError(400, "Can only create filters for local users")
try: content = parse_json_object_from_request(request)
content = json.loads(request.content.read())
# TODO(paul): check for required keys and invalid keys
except:
raise SynapseError(400, "Invalid filter definition")
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,

View File

@ -15,16 +15,15 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.servlet import RestServlet
from synapse.types import UserID from synapse.types import UserID
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from ._base import client_v2_patterns from ._base import client_v2_patterns
import simplejson as json
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,10 +67,9 @@ class KeyUploadServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication # TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead. # or derive the device_id from the authentication instead.
try:
body = json.loads(request.content.read()) body = parse_json_object_from_request(request)
except:
raise SynapseError(400, "Invalid key JSON")
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
@ -173,10 +171,7 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id): def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
try: body = parse_json_object_from_request(request)
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
result = yield self.handle_request(body) result = yield self.handle_request(body)
defer.returnValue(result) defer.returnValue(result)
@ -272,10 +267,7 @@ class OneTimeKeyServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm): def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
try: body = parse_json_object_from_request(request)
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
result = yield self.handle_request(body) result = yield self.handle_request(body)
defer.returnValue(result) defer.returnValue(result)

View File

@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler self.receipts_handler = hs.get_handlers().receipts_handler
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet):
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.presence_handler.bump_presence_active_time(requester.user)
yield self.receipts_handler.received_client_receipt( yield self.receipts_handler.received_client_receipt(
room_id, room_id,
receipt_type, receipt_type,

View File

@ -17,9 +17,9 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
import logging import logging
import hmac import hmac
@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest compare_digest = hmac.compare_digest
else: else:
compare_digest = lambda a, b: a == b def compare_digest(a, b):
return a == b
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,7 +73,7 @@ class RegisterRestServlet(RestServlet):
ret = yield self.onEmailTokenRequest(request) ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret) defer.returnValue(ret)
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these # we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us. # in sessions. Pull out the username/password provided to us.
@ -116,15 +117,27 @@ class RegisterRestServlet(RestServlet):
return return
# == Normal User Registration == (everyone else) # == Normal User Registration == (everyone else)
if self.hs.config.disable_registration: if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled") raise SynapseError(403, "Registration has been disabled")
guest_access_token = body.get("guest_access_token", None) guest_access_token = body.get("guest_access_token", None)
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if desired_username is not None: if desired_username is not None:
yield self.registration_handler.check_username( yield self.registration_handler.check_username(
desired_username, desired_username,
guest_access_token=guest_access_token guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
) )
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -138,7 +151,7 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
] ]
authed, result, params = yield self.auth_handler.check_auth( authed, result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
) )
@ -146,12 +159,29 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((401, result)) defer.returnValue((401, result))
return return
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
registered_user_id
)
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
refresh_token = yield self.auth_handler.issue_refresh_token(
registered_user_id
)
defer.returnValue((200, {
"user_id": registered_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
}))
# NB: This may be from the auth handler and NOT from the POST # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: if 'password' not in params:
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None) new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,
@ -159,6 +189,12 @@ class RegisterRestServlet(RestServlet):
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
) )
# remember that we've now registered that user account, and with what
# user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", user_id
)
if result and LoginType.EMAIL_IDENTITY in result: if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY] threepid = result[LoginType.EMAIL_IDENTITY]
@ -185,7 +221,7 @@ class RegisterRestServlet(RestServlet):
else: else:
logger.info("bind_email not specified: not binding email") logger.info("bind_email not specified: not binding email")
result = self._create_registration_details(user_id, token) result = yield self._create_registration_details(user_id, token)
defer.returnValue((200, result)) defer.returnValue((200, result))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
@ -196,7 +232,7 @@ class RegisterRestServlet(RestServlet):
(user_id, token) = yield self.registration_handler.appservice_register( (user_id, token) = yield self.registration_handler.appservice_register(
username, as_token username, as_token
) )
defer.returnValue(self._create_registration_details(user_id, token)) defer.returnValue((yield self._create_registration_details(user_id, token)))
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac): def _do_shared_secret_registration(self, username, password, mac):
@ -223,18 +259,21 @@ class RegisterRestServlet(RestServlet):
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=username, password=password localpart=username, password=password
) )
defer.returnValue(self._create_registration_details(user_id, token)) defer.returnValue((yield self._create_registration_details(user_id, token)))
@defer.inlineCallbacks
def _create_registration_details(self, user_id, token): def _create_registration_details(self, user_id, token):
return { refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
defer.returnValue({
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} "refresh_token": refresh_token,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def onEmailTokenRequest(self, request): def onEmailTokenRequest(self, request):
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt'] required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = [] absent = []

View File

@ -20,15 +20,16 @@ from synapse.http.servlet import (
) )
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id, serialize_event, format_event_for_client_v2_without_room_id,
) )
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from ._base import client_v2_patterns from ._base import client_v2_patterns
import copy import copy
import itertools
import logging import logging
import ujson as json import ujson as json
@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet):
self.sync_handler = hs.get_handlers().sync_handler self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -113,24 +115,24 @@ class SyncRestServlet(RestServlet):
) )
) )
if filter_id and filter_id.startswith('{'): if filter_id:
if filter_id.startswith('{'):
try: try:
filter_object = json.loads(filter_id) filter_object = json.loads(filter_id)
except: except:
raise SynapseError(400, "Invalid filter JSON") raise SynapseError(400, "Invalid filter JSON")
self.filtering._check_valid_filter(filter_object) self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object) filter = FilterCollection(filter_object)
else: else:
try:
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user.localpart, filter_id user.localpart, filter_id
) )
except: else:
filter = FilterCollection({}) filter = DEFAULT_FILTER_COLLECTION
sync_config = SyncConfig( sync_config = SyncConfig(
user=user, user=user,
filter=filter, filter_collection=filter,
is_guest=requester.is_guest, is_guest=requester.is_guest,
) )
@ -139,38 +141,38 @@ class SyncRestServlet(RestServlet):
else: else:
since_token = None since_token = None
if set_presence == "online": affect_presence = set_presence != PresenceState.OFFLINE
yield self.event_stream_handler.started_stream(user)
try: if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence})
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,
)
with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout, sync_config, since_token=since_token, timeout=timeout,
full_state=full_state full_state=full_state
) )
finally:
if set_presence == "online":
self.event_stream_handler.stopped_stream(user)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
joined = self.encode_joined( joined = self.encode_joined(
sync_result.joined, filter, time_now, requester.access_token_id sync_result.joined, time_now, requester.access_token_id
) )
invited = self.encode_invited( invited = self.encode_invited(
sync_result.invited, filter, time_now, requester.access_token_id sync_result.invited, time_now, requester.access_token_id
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, filter, time_now, requester.access_token_id sync_result.archived, time_now, requester.access_token_id
) )
response_content = { response_content = {
"account_data": self.encode_account_data( "account_data": {"events": sync_result.account_data},
sync_result.account_data, filter, time_now
),
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, filter, time_now sync_result.presence, time_now
), ),
"rooms": { "rooms": {
"join": joined, "join": joined,
@ -182,24 +184,20 @@ class SyncRestServlet(RestServlet):
defer.returnValue((200, response_content)) defer.returnValue((200, response_content))
def encode_presence(self, events, filter, time_now): def encode_presence(self, events, time_now):
formatted = [] formatted = []
for event in events: for event in events:
event = copy.deepcopy(event) event = copy.deepcopy(event)
event['sender'] = event['content'].pop('user_id') event['sender'] = event['content'].pop('user_id')
formatted.append(event) formatted.append(event)
return {"events": filter.filter_presence(formatted)} return {"events": formatted}
def encode_account_data(self, events, filter, time_now): def encode_joined(self, rooms, time_now, token_id):
return {"events": filter.filter_account_data(events)}
def encode_joined(self, rooms, filter, time_now, token_id):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -211,18 +209,17 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id room, time_now, token_id
) )
return joined return joined
def encode_invited(self, rooms, filter, time_now, token_id): def encode_invited(self, rooms, time_now, token_id):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -237,7 +234,9 @@ class SyncRestServlet(RestServlet):
room.invite, time_now, token_id=token_id, room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,
) )
invited_state = invite.get("unsigned", {}).pop("invite_room_state", []) unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
invited_state = list(unsigned.pop("invite_room_state", []))
invited_state.append(invite) invited_state.append(invite)
invited[room.room_id] = { invited[room.room_id] = {
"invite_state": {"events": invited_state} "invite_state": {"events": invited_state}
@ -245,13 +244,12 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, filter, time_now, token_id): def encode_archived(self, rooms, time_now, token_id):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -263,17 +261,16 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id, joined=False room, time_now, token_id, joined=False
) )
return joined return joined
@staticmethod @staticmethod
def encode_room(room, filter, time_now, token_id, joined=True): def encode_room(room, time_now, token_id, joined=True):
""" """
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a :param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room single room
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -292,19 +289,23 @@ class SyncRestServlet(RestServlet):
) )
state_dict = room.state state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events) timeline_events = room.timeline.events
state_dict = SyncRestServlet._rollback_state_for_timeline( state_events = state_dict.values()
state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values()) for event in itertools.chain(state_events, timeline_events):
# We've had bug reports that events were coming down under the
# wrong room.
if event.room_id != room.room_id:
logger.warn(
"Event %r is under room %r instead of %r",
event.event_id, room.room_id, event.room_id,
)
serialized_state = [serialize(e) for e in state_events] serialized_state = [serialize(e) for e in state_events]
serialized_timeline = [serialize(e) for e in timeline_events] serialized_timeline = [serialize(e) for e in timeline_events]
account_data = filter.filter_room_account_data( account_data = room.account_data
room.account_data
)
result = { result = {
"timeline": { "timeline": {
@ -317,85 +318,12 @@ class SyncRestServlet(RestServlet):
} }
if joined: if joined:
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications result["unread_notifications"] = room.unread_notifications
return result return result
@staticmethod
def _rollback_state_for_timeline(state, timeline):
"""
Wind the state dictionary backwards, so that it represents the
state at the start of the timeline, rather than at the end.
:param dict[(str, str), synapse.events.EventBase] state: the
state dictionary. Will be updated to the state before the timeline.
:param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary
"""
logger.debug("Processing state dict %r; timeline %r", state,
[e.get_dict() for e in timeline])
result = state.copy()
for timeline_event in reversed(timeline):
if not timeline_event.is_state():
continue
event_key = (timeline_event.type, timeline_event.state_key)
logger.debug("Considering %s for removal", event_key)
state_event = result.get(event_key)
if (state_event is None or
state_event.event_id != timeline_event.event_id):
# the event in the timeline isn't present in the state
# dictionary.
#
# the most likely cause for this is that there was a fork in
# the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore
# it for now, however.
logger.warn("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event)
continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
prev_content = timeline_event.unsigned.get('prev_content')
prev_sender = timeline_event.unsigned.get('prev_sender')
# Empircally it seems possible for the event to have a
# "replaces_state" key but not a prev_content or prev_sender
# markjh conjectures that it could be due to the server not
# having a copy of that event.
# If this is the case the we ignore the previous event. This will
# cause the displayname calculations on the client to be incorrect
if prev_event_id is None or not prev_content or not prev_sender:
logger.debug(
"Removing %r from the state dict, as it is missing"
" prev_content (prev_event_id=%r)",
timeline_event.event_id, prev_event_id
)
del result[event_key]
else:
logger.debug(
"Replacing %r with %r in state dict",
timeline_event.event_id, prev_event_id
)
result[event_key] = FrozenEvent({
"type": timeline_event.type,
"state_key": timeline_event.state_key,
"content": prev_content,
"sender": prev_sender,
"event_id": prev_event_id,
"room_id": timeline_event.room_id,
})
logger.debug("New value: %r", result.get(event_key))
return result
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)

View File

@ -15,15 +15,13 @@
from ._base import client_v2_patterns from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError
from twisted.internet import defer from twisted.internet import defer
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,15 +70,11 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
try: body = parse_json_object_from_request(request)
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid tag JSON")
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
@ -94,7 +88,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )

View File

@ -16,9 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -35,7 +35,7 @@ class TokenRefreshRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
try: try:
old_refresh_token = body["refresh_token"] old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler auth_handler = self.hs.get_handlers().auth_handler

View File

@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
return (200, { return (200, {
"versions": [ "versions": ["r0.0.1"]
"r0.0.1",
]
}) })

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from synapse.http.server import request_handler, respond_with_json_bytes from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -22,7 +22,6 @@ from twisted.internet import defer
from io import BytesIO from io import BytesIO
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -126,14 +125,7 @@ class RemoteKey(Resource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def async_render_POST(self, request): def async_render_POST(self, request):
try: content = parse_json_object_from_request(request)
content = json.loads(request.content.read())
if type(content) != dict:
raise ValueError()
except ValueError:
raise SynapseError(
400, "Content must be JSON object.", errcode=Codes.NOT_JSON
)
query = content["server_keys"] query = content["server_keys"]

Some files were not shown because too many files have changed in this diff Show More