Merge branch 'release-v0.14.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-03-30 12:36:40 +01:00
commit 5fbdf2bcec
180 changed files with 6519 additions and 5436 deletions

View File

@ -1,3 +1,68 @@
Changes in synapse v0.14.0 (2016-03-30)
=======================================
No changes from v0.14.0-rc2
Changes in synapse v0.14.0-rc2 (2016-03-23)
===========================================
Features:
* Add published room list API (PR #657)
Changes:
* Change various caches to consume less memory (PR #656, #658, #660, #662,
#663, #665)
* Allow rooms to be published without requiring an alias (PR #664)
* Intern common strings in caches to reduce memory footprint (#666)
Bug fixes:
* Fix reject invites over federation (PR #646)
* Fix bug where registration was not idempotent (PR #649)
* Update aliases event after deleting aliases (PR #652)
* Fix unread notification count, which was sometimes wrong (PR #661)
Changes in synapse v0.14.0-rc1 (2016-03-14)
===========================================
Features:
* Add event_id to response to state event PUT (PR #581)
* Allow guest users access to messages in rooms they have joined (PR #587)
* Add config for what state is included in a room invite (PR #598)
* Send the inviter's member event in room invite state (PR #607)
* Add error codes for malformed/bad JSON in /login (PR #608)
* Add support for changing the actions for default rules (PR #609)
* Add environment variable SYNAPSE_CACHE_FACTOR, default it to 0.1 (PR #612)
* Add ability for alias creators to delete aliases (PR #614)
* Add profile information to invites (PR #624)
Changes:
* Enforce user_id exclusivity for AS registrations (PR #572)
* Make adding push rules idempotent (PR #587)
* Improve presence performance (PR #582, #586)
* Change presence semantics for ``last_active_ago`` (PR #582, #586)
* Don't allow ``m.room.create`` to be changed (PR #596)
* Add 800x600 to default list of valid thumbnail sizes (PR #616)
* Always include kicks and bans in full /sync (PR #625)
* Send history visibility on boundary changes (PR #626)
* Register endpoint now returns a refresh_token (PR #637)
Bug fixes:
* Fix bug where we returned incorrect state in /sync (PR #573)
* Always return a JSON object from push rule API (PR #606)
* Fix bug where registering without a user id sometimes failed (PR #610)
* Report size of ExpiringCache in cache size metrics (PR #611)
* Fix rejection of invites to empty rooms (PR #615)
* Fix usage of ``bcrypt`` to not use ``checkpw`` (PR #619)
* Pin ``pysaml2`` dependency (PR #634)
* Fix bug in ``/sync`` where timeline order was incorrect for backfilled events
(PR #635)
Changes in synapse v0.13.3 (2016-02-11) Changes in synapse v0.13.3 (2016-02-11)
======================================= =======================================

View File

@ -21,5 +21,6 @@ recursive-include synapse/static *.html
recursive-include synapse/static *.js recursive-include synapse/static *.js
exclude jenkins.sh exclude jenkins.sh
exclude jenkins*.sh
prune demo/etc prune demo/etc

View File

@ -525,7 +525,6 @@ Logging In To An Existing Account
Just enter the ``@localpart:my.domain.here`` Matrix user ID and password into Just enter the ``@localpart:my.domain.here`` Matrix user ID and password into
the form and click the Login button. the form and click the Login button.
Identity Servers Identity Servers
================ ================
@ -545,6 +544,26 @@ as the primary means of identity and E2E encryption is not complete. As such,
we are running a single identity server (https://matrix.org) at the current we are running a single identity server (https://matrix.org) at the current
time. time.
Password reset
==============
If a user has registered an email address to their account using an identity
server, they can request a password-reset token via clients such as Vector.
A manual password reset can be done via direct database access as follows.
First calculate the hash of the new password:
$ source ~/.synapse/bin/activate
$ ./scripts/hash_password
Password:
Confirm password:
$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
Then update the `users` table in the database:
UPDATE users SET password_hash='$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
WHERE name='@test:test.com';
Where's the spec?! Where's the spec?!
================== ==================
@ -566,3 +585,20 @@ Building internal API documentation::
python setup.py build_sphinx python setup.py build_sphinx
Halp!! Synapse eats all my RAM!
===============================
Synapse's architecture is quite RAM hungry currently - we deliberately
cache a lot of recent room data and metadata in RAM in order to speed up
common requests. We'll improve this in future, but for now the easiest
way to either reduce the RAM usage (at the risk of slowing things down)
is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
variable. Roughly speaking, a SYNAPSE_CACHE_FACTOR of 1.0 will max out
at around 3-4GB of resident memory - this is what we currently run the
matrix.org on. The default setting is currently 0.1, which is probably
around a ~700MB footprint. You can dial it down further to 0.02 if
desired, which targets roughly ~512MB. Conversely you can dial it up if
you need performance for lots of users and have a box with a lot of RAM.

22
jenkins-flake8.sh Executable file
View File

@ -0,0 +1,22 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
export PEP8SUFFIX="--output-file=violations.flake8.log"
rm .coverage* || echo "No coverage files to remove"
tox -e packaging -e pep8

61
jenkins-postgres.sh Executable file
View File

@ -0,0 +1,61 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
$TOX_BIN/pip install psycopg2
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
./jenkins/prep_sytest_for_postgres.sh
echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-base $PORT_BASE
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

55
jenkins-sqlite.sh Executable file
View File

@ -0,0 +1,55 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8500}
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-base $PORT_BASE
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

25
jenkins-unittests.sh Executable file
View File

@ -0,0 +1,25 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox -e py27

View File

@ -1,6 +1,11 @@
#!/bin/bash -eu #!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml # Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit" export TRIAL_FLAGS="--reporter=subunit"

View File

@ -86,9 +86,12 @@ def used_names(prefix, item, defs, names):
for name, funcs in defs.get('class', {}).items(): for name, funcs in defs.get('class', {}).items():
used_names(prefix + name + ".", name, funcs, names) used_names(prefix + name + ".", name, funcs, names)
path = prefix.rstrip('.')
for used in defs.get('uses', ()): for used in defs.get('uses', ()):
if used in names: if used in names:
names[used].setdefault('used', {}).setdefault(item, []).append(prefix.rstrip('.')) if item:
names[item].setdefault('uses', []).append(used)
names[used].setdefault('used', {}).setdefault(item, []).append(path)
if __name__ == '__main__': if __name__ == '__main__':
@ -113,6 +116,10 @@ if __name__ == '__main__':
"--referrers", default=0, type=int, "--referrers", default=0, type=int,
help="Include referrers up to the given depth" help="Include referrers up to the given depth"
) )
parser.add_argument(
"--referred", default=0, type=int,
help="Include referred down to the given depth"
)
parser.add_argument( parser.add_argument(
"--format", default="yaml", "--format", default="yaml",
help="Output format, one of 'yaml' or 'dot'" help="Output format, one of 'yaml' or 'dot'"
@ -161,6 +168,20 @@ if __name__ == '__main__':
continue continue
result[name] = definition result[name] = definition
referred_depth = args.referred
referred = set()
while referred_depth:
referred_depth -= 1
for entry in result.values():
for uses in entry.get("uses", ()):
referred.add(uses)
for name, definition in names.items():
if not name in referred:
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
result[name] = definition
if args.format == 'yaml': if args.format == 'yaml':
yaml.dump(result, sys.stdout, default_flow_style=False) yaml.dump(result, sys.stdout, default_flow_style=False)
elif args.format == 'dot': elif args.format == 'dot':

View File

@ -0,0 +1,67 @@
import requests
import collections
import sys
import time
import json
Entry = collections.namedtuple("Entry", "name position rows")
ROW_TYPES = {}
def row_type_for_columns(name, column_names):
column_names = tuple(column_names)
row_type = ROW_TYPES.get((name, column_names))
if row_type is None:
row_type = collections.namedtuple(name, column_names)
ROW_TYPES[(name, column_names)] = row_type
return row_type
def parse_response(content):
streams = json.loads(content)
result = {}
for name, value in streams.items():
row_type = row_type_for_columns(name, value["field_names"])
position = value["position"]
rows = [row_type(*row) for row in value["rows"]]
result[name] = Entry(name, position, rows)
return result
def replicate(server, streams):
return parse_response(requests.get(
server + "/_synapse/replication",
verify=False,
params=streams
).content)
def main():
server = sys.argv[1]
streams = None
while not streams:
try:
streams = {
row.name: row.position
for row in replicate(server, {"streams":"-1"})["streams"].rows
}
except requests.exceptions.ConnectionError as e:
time.sleep(0.1)
while True:
try:
results = replicate(server, streams)
except:
sys.stdout.write("connection_lost("+ repr(streams) + ")\n")
break
for update in results.values():
for row in update.rows:
sys.stdout.write(repr(row) + "\n")
streams[update.name] = update.position
if __name__=='__main__':
main()

View File

@ -1 +0,0 @@
perl -MCrypt::Random -MCrypt::Eksblowfish::Bcrypt -e 'print Crypt::Eksblowfish::Bcrypt::bcrypt("secret", "\$2\$12\$" . Crypt::Eksblowfish::Bcrypt::en_base64(Crypt::Random::makerandom_octet(Length=>16)))."\n"'

39
scripts/hash_password Executable file
View File

@ -0,0 +1,39 @@
#!/usr/bin/env python
import argparse
import bcrypt
import getpass
bcrypt_rounds=12
def prompt_for_pass():
password = getpass.getpass("Password: ")
if not password:
raise Exception("Password cannot be blank.")
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
raise Exception("Passwords do not match.")
return password
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Calculate the hash of a new password, so that passwords"
" can be reset")
parser.add_argument(
"-p", "--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
args = parser.parse_args()
password = args.password
if not password:
password = prompt_for_pass()
print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))

View File

@ -309,8 +309,8 @@ class Porter(object):
**self.postgres_config["args"] **self.postgres_config["args"]
) )
sqlite_engine = create_engine("sqlite3") sqlite_engine = create_engine(FakeConfig(sqlite_config))
postgres_engine = create_engine("psycopg2") postgres_engine = create_engine(FakeConfig(postgres_config))
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine)
@ -792,3 +792,8 @@ if __name__ == "__main__":
if end_error_exec_info: if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback) traceback.print_exception(exc_type, exc_value, exc_traceback)
class FakeConfig:
def __init__(self, database_config):
self.database_config = database_config

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.13.3" __version__ = "0.14.0"

View File

@ -434,31 +434,46 @@ class Auth(object):
if event.user_id != invite_event.user_id: if event.user_id != invite_event.user_id:
return False return False
try:
public_key = invite_event.content["public_key"]
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
return False
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the if signed["mxid"] != event.state_key:
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
return False return False
except (KeyError, SignatureVerifyException,): if signed["token"] != token:
return False return False
for public_key_object in self.get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(self, invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def _get_power_level_event(self, auth_events): def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
return auth_events.get(key) return auth_events.get(key)
@ -519,7 +534,7 @@ class Auth(object):
) )
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self.get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
@ -580,7 +595,7 @@ class Auth(object):
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_user_by_access_token(self, token): def get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -799,17 +814,16 @@ class Auth(object):
return auth_ids return auth_ids
@log_function def _get_send_level(self, etype, state_key, auth_events):
def _can_send_event(self, event, auth_events):
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key) send_level_event = auth_events.get(key)
send_level = None send_level = None
if send_level_event: if send_level_event:
send_level = send_level_event.content.get("events", {}).get( send_level = send_level_event.content.get("events", {}).get(
event.type etype
) )
if send_level is None: if send_level is None:
if hasattr(event, "state_key"): if state_key is not None:
send_level = send_level_event.content.get( send_level = send_level_event.content.get(
"state_default", 50 "state_default", 50
) )
@ -823,6 +837,13 @@ class Auth(object):
else: else:
send_level = 0 send_level = 0
return send_level
@log_function
def _can_send_event(self, event, auth_events):
send_level = self._get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = self._get_user_power_level(event.user_id, auth_events) user_level = self._get_user_power_level(event.user_id, auth_events)
if user_level < send_level: if user_level < send_level:
@ -967,3 +988,43 @@ class Auth(object):
"You don't have permission to add ops level greater " "You don't have permission to add ops level greater "
"than your own" "than your own"
) )
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
"""Check if the user is allowed to edit the room's entry in the
published room list.
Args:
room_id (str)
user (UserID)
"""
is_admin = yield self.is_server_admin(user)
if is_admin:
defer.returnValue(True)
user_id = user.to_string()
yield self.check_joined_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.aliases events
power_level_event = yield self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
)
auth_events = {}
if power_level_event:
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = self._get_send_level(
EventTypes.Aliases, "", auth_events
)
user_level = self._get_user_power_level(user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"This server requires you to be a moderator in the room to"
" edit its room list entry"
)

View File

@ -32,7 +32,6 @@ class PresenceState(object):
OFFLINE = u"offline" OFFLINE = u"offline"
UNAVAILABLE = u"unavailable" UNAVAILABLE = u"unavailable"
ONLINE = u"online" ONLINE = u"online"
FREE_FOR_CHAT = u"free_for_chat"
class JoinRules(object): class JoinRules(object):

View File

@ -198,7 +198,10 @@ class Filter(object):
sender = event.get("sender", None) sender = event.get("sender", None)
if not sender: if not sender:
# Presence events have their 'sender' in content.user_id # Presence events have their 'sender' in content.user_id
sender = event.get("content", {}).get("user_id", None) content = event.get("content")
# account_data has been allowed to have non-dict content, so check type first
if isinstance(content, dict):
sender = content.get("user_id")
return self.check_fields( return self.check_fields(
event.get("room_id", None), event.get("room_id", None),

View File

@ -63,6 +63,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse import events from synapse import events
@ -169,6 +170,9 @@ class SynapseHomeServer(HomeServer):
if name == "metrics" and self.get_config().enable_metrics: if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationResource(self)
root_resource = create_resource_tree(resources) root_resource = create_resource_tree(resources)
if tls: if tls:
reactor.listenSSL( reactor.listenSSL(
@ -382,7 +386,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"]) database_engine = create_engine(config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
@ -718,7 +722,7 @@ def run(hs):
if hs.config.daemonize: if hs.config.daemonize:
if hs.config.print_pidfile: if hs.config.print_pidfile:
print hs.config.pid_file print (hs.config.pid_file)
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",

View File

@ -29,13 +29,13 @@ NORMAL = "\x1b[m"
def start(configfile): def start(configfile):
print "Starting ...", print ("Starting ...")
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", configfile]) args.extend(["--daemonize", "-c", configfile])
try: try:
subprocess.check_call(args) subprocess.check_call(args)
print GREEN + "started" + NORMAL print (GREEN + "started" + NORMAL)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print ( print (
RED + RED +
@ -48,7 +48,7 @@ def stop(pidfile):
if os.path.exists(pidfile): if os.path.exists(pidfile):
pid = int(open(pidfile).read()) pid = int(open(pidfile).read())
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL print (GREEN + "stopped" + NORMAL)
def main(): def main():

View File

@ -28,7 +28,7 @@ if __name__ == "__main__":
sys.stderr.write("\n" + e.message + "\n") sys.stderr.write("\n" + e.message + "\n")
sys.exit(1) sys.exit(1)
print getattr(config, key) print (getattr(config, key))
sys.exit(0) sys.exit(0)
else: else:
sys.stderr.write("Unknown command %r\n" % (action,)) sys.stderr.write("Unknown command %r\n" % (action,))

View File

@ -104,7 +104,7 @@ class Config(object):
dir_path = cls.abspath(dir_path) dir_path = cls.abspath(dir_path)
try: try:
os.makedirs(dir_path) os.makedirs(dir_path)
except OSError, e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):

40
synapse/config/api.py Normal file
View File

@ -0,0 +1,40 @@
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from synapse.api.constants import EventTypes
class ApiConfig(Config):
def read_config(self, config):
self.room_invite_state_types = config.get("room_invite_state_types", [
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
])
def default_config(cls, **kwargs):
return """\
## API Configuration ##
# A list of event types that will be included in the room_invite_state
room_invite_state_types:
- "{JoinRules}"
- "{CanonicalAlias}"
- "{RoomAvatar}"
- "{Name}"
""".format(**vars(EventTypes))

View File

@ -23,6 +23,7 @@ from .captcha import CaptchaConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .api import ApiConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
@ -32,7 +33,7 @@ from .password import PasswordConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,): PasswordConfig,):
pass pass

View File

@ -37,6 +37,10 @@ class RegistrationConfig(Config):
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.allow_guest_access = config.get("allow_guest_access", False) self.allow_guest_access = config.get("allow_guest_access", False)
self.invite_3pid_guest = (
self.allow_guest_access and config.get("invite_3pid_guest", False)
)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)

View File

@ -97,4 +97,7 @@ class ContentRepositoryConfig(Config):
- width: 640 - width: 640
height: 480 height: 480
method: scale method: scale
- width: 800
height: 600
method: scale
""" % locals() """ % locals()

View File

@ -36,6 +36,7 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
factory = SynapseKeyClientFactory() factory = SynapseKeyClientFactory()
factory.path = path factory.path = path
factory.host = server_name
endpoint = matrix_federation_endpoint( endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30 reactor, server_name, ssl_context_factory, timeout=30
) )
@ -81,6 +82,8 @@ class SynapseKeyClientProtocol(HTTPClient):
self.host = self.transport.getHost() self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host) logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", self.path) self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
self.endHeaders() self.endHeaders()
self.timer = reactor.callLater( self.timer = reactor.callLater(
self.timeout, self.timeout,

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
from synapse.util.caches import intern_dict
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
@ -140,6 +141,10 @@ class FrozenEvent(EventBase):
unsigned = dict(event_dict.pop("unsigned", {})) unsigned = dict(event_dict.pop("unsigned", {}))
# We intern these strings because they turn up a lot (especially when
# caching).
event_dict = intern_dict(event_dict)
if USE_FROZEN_DICTS: if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict) frozen_dict = freeze(event_dict)
else: else:
@ -168,5 +173,7 @@ class FrozenEvent(EventBase):
def __repr__(self): def __repr__(self):
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % ( return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
self.event_id, self.type, self.get("state_key", None), self.get("event_id", None),
self.get("type", None),
self.get("state_key", None),
) )

View File

@ -114,7 +114,7 @@ class FederationClient(FederationBase):
@log_function @log_function
def make_query(self, destination, query_type, args, def make_query(self, destination, query_type, args,
retry_on_dns_fail=True): retry_on_dns_fail=False):
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.
@ -418,6 +418,7 @@ class FederationClient(FederationBase):
"Failed to make_%s via %s: %s", "Failed to make_%s via %s: %s",
membership, destination, e.message membership, destination, e.message
) )
raise
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")

View File

@ -137,8 +137,8 @@ class FederationServer(FederationBase):
logger.exception("Failed to handle PDU") logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in (Edu(**x) for x in transaction.edus):
self.received_edu( yield self.received_edu(
transaction.origin, transaction.origin,
edu.edu_type, edu.edu_type,
edu.content edu.content
@ -161,11 +161,17 @@ class FederationServer(FederationBase):
) )
defer.returnValue((200, response)) defer.returnValue((200, response))
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content): def received_edu(self, origin, edu_type, content):
received_edus_counter.inc() received_edus_counter.inc()
if edu_type in self.edu_handlers: if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content) try:
yield self.edu_handlers[edu_type](origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type, e)
else: else:
logger.warn("Received EDU of type %s with no handler", edu_type) logger.warn("Received EDU of type %s with no handler", edu_type)
@ -525,7 +531,6 @@ class FederationServer(FederationBase):
yield self.handler.on_receive_pdu( yield self.handler.on_receive_pdu(
origin, origin,
pdu, pdu,
backfilled=False,
state=state, state=state,
auth_chain=auth_chain, auth_chain=auth_chain,
) )
@ -543,8 +548,19 @@ class FederationServer(FederationBase):
return event return event
@defer.inlineCallbacks @defer.inlineCallbacks
def exchange_third_party_invite(self, invite): def exchange_third_party_invite(
ret = yield self.handler.exchange_third_party_invite(invite) self,
sender_user_id,
target_user_id,
room_id,
signed,
):
ret = yield self.handler.exchange_third_party_invite(
sender_user_id,
target_user_id,
room_id,
signed,
)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -160,6 +160,7 @@ class TransportLayerClient(object):
path=path, path=path,
args=args, args=args,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
) )
defer.returnValue(content) defer.returnValue(content)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
@ -174,7 +175,7 @@ class BaseFederationServlet(object):
class FederationSendServlet(BaseFederationServlet): class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/" PATH = "/send/(?P<transaction_id>[^/]*)/"
def __init__(self, handler, server_name, **kwargs): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__( super(FederationSendServlet, self).__init__(
@ -249,7 +250,7 @@ class FederationPullServlet(BaseFederationServlet):
class FederationEventServlet(BaseFederationServlet): class FederationEventServlet(BaseFederationServlet):
PATH = "/event/([^/]*)/" PATH = "/event/(?P<event_id>[^/]*)/"
# This is when someone asks for a data item for a given server data_id pair. # This is when someone asks for a data item for a given server data_id pair.
def on_GET(self, origin, content, query, event_id): def on_GET(self, origin, content, query, event_id):
@ -257,7 +258,7 @@ class FederationEventServlet(BaseFederationServlet):
class FederationStateServlet(BaseFederationServlet): class FederationStateServlet(BaseFederationServlet):
PATH = "/state/([^/]*)/" PATH = "/state/(?P<context>[^/]*)/"
# This is when someone asks for all data for a given context. # This is when someone asks for all data for a given context.
def on_GET(self, origin, content, query, context): def on_GET(self, origin, content, query, context):
@ -269,7 +270,7 @@ class FederationStateServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/([^/]*)/" PATH = "/backfill/(?P<context>[^/]*)/"
def on_GET(self, origin, content, query, context): def on_GET(self, origin, content, query, context):
versions = query["v"] versions = query["v"]
@ -284,7 +285,7 @@ class FederationBackfillServlet(BaseFederationServlet):
class FederationQueryServlet(BaseFederationServlet): class FederationQueryServlet(BaseFederationServlet):
PATH = "/query/([^/]*)" PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query # This is when we receive a server-server Query
def on_GET(self, origin, content, query, query_type): def on_GET(self, origin, content, query, query_type):
@ -295,7 +296,7 @@ class FederationQueryServlet(BaseFederationServlet):
class FederationMakeJoinServlet(BaseFederationServlet): class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/([^/]*)/([^/]*)" PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, content, query, context, user_id):
@ -304,7 +305,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
class FederationMakeLeaveServlet(BaseFederationServlet): class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/([^/]*)/([^/]*)" PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, content, query, context, user_id):
@ -313,7 +314,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
class FederationSendLeaveServlet(BaseFederationServlet): class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/([^/]*)/([^/]*)" PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, txid): def on_PUT(self, origin, content, query, room_id, txid):
@ -322,14 +323,14 @@ class FederationSendLeaveServlet(BaseFederationServlet):
class FederationEventAuthServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/([^/]*)/([^/]*)" PATH = "/event_auth(?P<context>[^/]*)/(?P<event_id>[^/]*)"
def on_GET(self, origin, content, query, context, event_id): def on_GET(self, origin, content, query, context, event_id):
return self.handler.on_event_auth(origin, context, event_id) return self.handler.on_event_auth(origin, context, event_id)
class FederationSendJoinServlet(BaseFederationServlet): class FederationSendJoinServlet(BaseFederationServlet):
PATH = "/send_join/([^/]*)/([^/]*)" PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id): def on_PUT(self, origin, content, query, context, event_id):
@ -340,7 +341,7 @@ class FederationSendJoinServlet(BaseFederationServlet):
class FederationInviteServlet(BaseFederationServlet): class FederationInviteServlet(BaseFederationServlet):
PATH = "/invite/([^/]*)/([^/]*)" PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id): def on_PUT(self, origin, content, query, context, event_id):
@ -351,7 +352,7 @@ class FederationInviteServlet(BaseFederationServlet):
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/([^/]*)" PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id): def on_PUT(self, origin, content, query, room_id):
@ -380,7 +381,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
class FederationQueryAuthServlet(BaseFederationServlet): class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)" PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, origin, content, query, context, event_id): def on_POST(self, origin, content, query, context, event_id):
@ -393,7 +394,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
class FederationGetMissingEventsServlet(BaseFederationServlet): class FederationGetMissingEventsServlet(BaseFederationServlet):
# TODO(paul): Why does this path alone end with "/?" optional? # TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/([^/]*)/?" PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, origin, content, query, room_id): def on_POST(self, origin, content, query, room_id):
@ -419,13 +420,22 @@ class On3pidBindServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
content_bytes = request.content.read() content = parse_json_object_from_request(request)
content = json.loads(content_bytes)
if "invites" in content: if "invites" in content:
last_exception = None last_exception = None
for invite in content["invites"]: for invite in content["invites"]:
try: try:
yield self.handler.exchange_third_party_invite(invite) if "signed" not in invite or "token" not in invite["signed"]:
message = ("Rejecting received notification of third-"
"party invite without signed: %s" % (invite,))
logger.info(message)
raise SynapseError(400, message)
yield self.handler.exchange_third_party_invite(
invite["sender"],
invite["mxid"],
invite["room_id"],
invite["signed"],
)
except Exception as e: except Exception as e:
last_exception = e last_exception = e
if last_exception: if last_exception:

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias, Requester
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
@ -29,6 +29,14 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VISIBILITY_PRIORITY = (
"world_readable",
"shared",
"invited",
"joined",
)
class BaseHandler(object): class BaseHandler(object):
""" """
Common base class for the event handlers. Common base class for the event handlers.
@ -53,9 +61,15 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_clients(self, user_tuples, events, event_id_to_state): def filter_events_for_clients(self, user_tuples, events, event_id_to_state):
""" Returns dict of user_id -> list of events that user is allowed to """ Returns dict of user_id -> list of events that user is allowed to
see. see.
:param (str, bool) user_tuples: (user id, is_peeking) for each
user to be checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
""" """
forgotten = yield defer.gatherResults([ forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room( self.store.who_forgot_in_room(
@ -72,18 +86,38 @@ class BaseHandler(object):
def allowed(event, user_id, is_peeking): def allowed(event, user_id, is_peeking):
state = event_id_to_state[event.event_id] state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event: if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared") visibility = visibility_event.content.get("history_visibility", "shared")
else: else:
visibility = "shared" visibility = "shared"
if visibility not in VISIBILITY_PRIORITY:
visibility = "shared"
# if it was world_readable, it's easy: everyone can read it
if visibility == "world_readable": if visibility == "world_readable":
return True return True
if is_peeking: # Always allow history visibility events on boundaries. This is done
return False # by setting the effective visibility to the least restrictive
# of the old vs new.
if event.type == EventTypes.RoomHistoryVisibility:
prev_content = event.unsigned.get("prev_content", {})
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
prev_visibility = "shared"
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
if old_priority < new_priority:
visibility = prev_visibility
# get the user's membership at the time of the event. (or rather,
# just *after* the event. Which means that people can see their
# own join events, but not (currently) their own leave events.)
membership_event = state.get((EventTypes.Member, user_id), None) membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event: if membership_event:
if membership_event.event_id in event_id_forgotten: if membership_event.event_id in event_id_forgotten:
@ -93,20 +127,29 @@ class BaseHandler(object):
else: else:
membership = None membership = None
# if the user was a member of the room at the time of the event,
# they can see it.
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
if event.type == EventTypes.RoomHistoryVisibility: if visibility == "joined":
return not is_peeking # we weren't a member at the time of the event, so we can't
# see this event.
return False
if visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited": elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
return membership == Membership.INVITE return membership == Membership.INVITE
return True else:
# visibility is shared: user can also see the event if they have
# become a member since the event
#
# XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But
# we don't know when they left.
return not is_peeking
defer.returnValue({ defer.returnValue({
user_id: [ user_id: [
@ -119,7 +162,17 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False): def _filter_events_for_client(self, user_id, events, is_peeking=False):
# Assumes that user has at some point joined the room if not is_guest. """
Check which events a user is allowed to see
:param str user_id: user id to be checked
:param [synapse.events.EventBase] events: list of events to be checked
:param bool is_peeking should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
:rtype [synapse.events.EventBase]
"""
types = ( types = (
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id), (EventTypes.Member, user_id),
@ -128,15 +181,15 @@ class BaseHandler(object):
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=types types=types
) )
res = yield self._filter_events_for_clients( res = yield self.filter_events_for_clients(
[(user_id, is_peeking)], events, event_id_to_state [(user_id, is_peeking)], events, event_id_to_state
) )
defer.returnValue(res.get(user_id, [])) defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id): def ratelimit(self, requester):
time_now = self.clock.time() time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now, requester.user.to_string(), time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count, burst_count=self.hs.config.rc_message_burst_count,
) )
@ -147,7 +200,7 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder):
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id, builder.room_id,
) )
@ -156,7 +209,10 @@ class BaseHandler(object):
else: else:
depth = 1 depth = 1
prev_events = [(e, h) for e, h, _ in latest_ret] prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
@ -165,6 +221,50 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder) context = yield state_handler.compute_event_context(builder)
# If we've received an invite over federation, there are no latest
# events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite.
if (
not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member
):
prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id
)
# The prev_member_event may already be in context.current_state,
# despite us not being present in the room; in particular, if
# inviting user, and all other local users, have already left.
#
# In that case, we have all the information we need, and we don't
# want to drop "context" - not least because we may need to handle
# the invite locally, which will require us to have the whole
# context (not just prev_member_event) to auth it.
#
context_event_ids = (
e.event_id for e in context.current_state.values()
)
if (
prev_member_event and
prev_member_event.event_id not in context_event_ids
):
# The prev_member_event is missing from context, so it must
# have arrived over federation and is an outlier. We forcibly
# set our context to the invite we received over federation
builder.prev_events = (
prev_member_event.event_id,
prev_member_event.prev_events
)
context = yield state_handler.compute_event_context(
builder,
old_state=(prev_member_event,),
outlier=True
)
if builder.is_state(): if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes( builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events context.prev_state_events
@ -187,10 +287,40 @@ class BaseHandler(object):
(event, context,) (event, context,)
) )
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
for ((event_type, state_key), event) in current_state.items()
if event_type == EventTypes.Member
]
if len(room_members) == 0:
# Have we just created the room, and is this about to be the very
# first member event?
create_event = current_state.get(("m.room.create", ""))
if create_event:
return True
for (state_key, membership) in room_members:
if (
UserID.from_string(state_key).domain == self.hs.hostname
and membership == Membership.JOIN
):
return True
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_users=[]): def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[]
):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(requester)
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context.current_state.values())
@ -215,6 +345,12 @@ class BaseHandler(object):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE: if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return (
e.type == EventTypes.Member and
e.sender == event.sender
)
event.unsigned["invite_room_state"] = [ event.unsigned["invite_room_state"] = [
{ {
"type": e.type, "type": e.type,
@ -223,12 +359,8 @@ class BaseHandler(object):
"sender": e.sender, "sender": e.sender,
} }
for k, e in context.current_state.items() for k, e in context.current_state.items()
if e.type in ( if e.type in self.hs.config.room_invite_state_types
EventTypes.JoinRules, or is_inviter_member_event(e)
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
)
] ]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)
@ -264,6 +396,12 @@ class BaseHandler(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
action_generator = ActionGenerator(self.hs) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, context, self event, context, self
@ -316,7 +454,8 @@ class BaseHandler(object):
if member_event.type != EventTypes.Member: if member_event.type != EventTypes.Member:
continue continue
if not self.hs.is_mine(UserID.from_string(member_event.state_key)): target_user = UserID.from_string(member_event.state_key)
if not self.hs.is_mine(target_user):
continue continue
if member_event.content["membership"] not in { if member_event.content["membership"] not in {
@ -338,18 +477,13 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more # and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having # of that decision-making and control local to the guest-having
# homeserver. # homeserver.
message_handler = self.hs.get_handlers().message_handler requester = Requester(target_user, "", True)
yield message_handler.create_and_send_event( handler = self.hs.get_handlers().room_member_handler
{ yield handler.update_membership(
"type": EventTypes.Member, requester,
"state_key": member_event.state_key, target_user,
"content": { member_event.room_id,
"membership": Membership.LEAVE, "leave",
"kind": "guest"
},
"room_id": member_event.room_id,
"sender": member_event.state_key
},
ratelimit=False, ratelimit=False,
) )
except Exception as e: except Exception as e:

View File

@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler): class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
@ -66,15 +67,18 @@ class AuthHandler(BaseHandler):
'auth' key: this method prompts for auth if none is sent. 'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client. clientip (str): The IP address of the client.
Returns: Returns:
A tuple of (authed, dict, dict) where authed is true if the client A tuple of (authed, dict, dict, session_id) where authed is true if
has successfully completed an auth flow. If it is true, the first the client has successfully completed an auth flow. If it is true
dict contains the authenticated credentials of each stage. the first dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client. the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call). request (which may have been given only in a previous call).
session_id is the ID of this session, either passed in by the client
or assigned by the call to check_auth
""" """
authdict = None authdict = None
@ -103,7 +107,10 @@ class AuthHandler(BaseHandler):
if not authdict: if not authdict:
defer.returnValue( defer.returnValue(
(False, self._auth_dict_for_flows(flows, session), clientdict) (
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
) )
if 'creds' not in session: if 'creds' not in session:
@ -122,12 +129,11 @@ class AuthHandler(BaseHandler):
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds) logger.info("Auth completed with creds: %r", creds)
self._remove_session(session) defer.returnValue((True, creds, clientdict, session['id']))
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict)) defer.returnValue((False, ret, clientdict, session['id']))
@defer.inlineCallbacks @defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip): def add_oob_auth(self, stagetype, authdict, clientip):
@ -154,6 +160,43 @@ class AuthHandler(BaseHandler):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
def get_session_id(self, clientdict):
"""
Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not
send a session ID, returns None.
"""
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
return sid
def set_session_data(self, session_id, key, value):
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param value: (any) The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
"""
Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param default: (any) Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
@ -432,13 +475,18 @@ class AuthHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else []
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id) yield self.store.user_delete_access_tokens(
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id) user_id, except_access_token_ids
yield self.store.flush_user(user_id) )
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):
@ -450,11 +498,18 @@ class AuthHandler(BaseHandler):
def _save_session(self, session): def _save_session(self, session):
# TODO: Persistent storage # TODO: Persistent storage
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session self.sessions[session["id"]] = session
self._prune_sessions()
def _remove_session(self, session): def _prune_sessions(self):
logger.debug("Removing session %s", session) for sid, sess in self.sessions.items():
del self.sessions[session["id"]] last_used = 0
if 'last_used' in sess:
last_used = sess['last_used']
now = self.hs.get_clock().time_msec()
if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
del self.sessions[sid]
def hash(self, password): def hash(self, password):
"""Computes a secure hash of password. """Computes a secure hash of password.
@ -477,4 +532,4 @@ class AuthHandler(BaseHandler):
Returns: Returns:
Whether self.hash(password) == stored_hash (bool). Whether self.hash(password) == stored_hash (bool).
""" """
return bcrypt.checkpw(password, stored_hash) return bcrypt.hashpw(password, stored_hash) == stored_hash

View File

@ -17,9 +17,9 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomAlias from synapse.types import RoomAlias, UserID
import logging import logging
import string import string
@ -32,13 +32,15 @@ class DirectoryHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs) super(DirectoryHandler, self).__init__(hs)
self.state = hs.get_state_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( self.federation.register_query_handler(
"directory", self.on_directory_query "directory", self.on_directory_query
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None): def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services # general association creation for both human users and app services
for wchar in string.whitespace: for wchar in string.whitespace:
@ -60,7 +62,8 @@ class DirectoryHandler(BaseHandler):
yield self.store.create_room_alias_association( yield self.store.create_room_alias_association(
room_alias, room_alias,
room_id, room_id,
servers servers,
creator=creator,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -77,7 +80,7 @@ class DirectoryHandler(BaseHandler):
400, "This alias is reserved by an application service.", 400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
yield self._create_association(room_alias, room_id, servers) yield self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_appservice_association(self, service, room_alias, room_id, def create_appservice_association(self, service, room_alias, room_id,
@ -92,10 +95,14 @@ class DirectoryHandler(BaseHandler):
yield self._create_association(room_alias, room_id, servers) yield self._create_association(room_alias, room_id, servers)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_association(self, user_id, room_alias): def delete_association(self, requester, user_id, room_alias):
# association deletion for human users # association deletion for human users
# TODO Check if server admin can_delete = yield self._user_can_delete_alias(room_alias, user_id)
if not can_delete:
raise AuthError(
403, "You don't have permission to delete the alias.",
)
can_delete = yield self.can_modify_alias( can_delete = yield self.can_modify_alias(
room_alias, room_alias,
@ -107,7 +114,25 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
yield self._delete_association(room_alias) room_id = yield self._delete_association(room_alias)
try:
yield self.send_room_alias_update_event(
requester,
requester.user.to_string(),
room_id
)
yield self._update_canonical_alias(
requester,
requester.user.to_string(),
room_id,
room_alias,
)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
defer.returnValue(room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias): def delete_appservice_association(self, service, room_alias):
@ -124,11 +149,9 @@ class DirectoryHandler(BaseHandler):
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
yield self.store.delete_room_alias(room_alias) room_id = yield self.store.delete_room_alias(room_alias)
# TODO - Looks like _update_room_alias_event has never been implemented defer.returnValue(room_id)
# if room_id:
# yield self._update_room_alias_events(user_id, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association(self, room_alias): def get_association(self, room_alias):
@ -212,17 +235,44 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_room_alias_update_event(self, user_id, room_id): def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event({ yield msg_handler.create_and_send_nonmember_event(
"type": EventTypes.Aliases, requester,
"state_key": self.hs.hostname, {
"room_id": room_id, "type": EventTypes.Aliases,
"sender": user_id, "state_key": self.hs.hostname,
"content": {"aliases": aliases}, "room_id": room_id,
}, ratelimit=False) "sender": user_id,
"content": {"aliases": aliases},
},
ratelimit=False
)
@defer.inlineCallbacks
def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
alias_event = yield self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
alias_str = room_alias.to_string()
if not alias_event or alias_event.content.get("alias", "") != alias_str:
return
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
"state_key": "",
"room_id": room_id,
"sender": user_id,
"content": {},
},
ratelimit=False
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias): def get_association_from_room_alias(self, room_alias):
@ -257,3 +307,35 @@ class DirectoryHandler(BaseHandler):
return return
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
defer.returnValue(True) defer.returnValue(True)
@defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id):
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator and creator == user_id:
defer.returnValue(True)
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
defer.returnValue(is_admin)
@defer.inlineCallbacks
def edit_published_room_list(self, requester, room_id, visibility):
"""Edit the entry of the room in the published room list.
requester
room_id (str)
visibility (str): "public" or "private"
"""
if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list")
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
room = yield self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Unknown room")
yield self.auth.check_can_change_room_list(room_id, requester.user)
yield self.store.set_room_is_public(room_id, visibility == "public")

View File

@ -18,7 +18,8 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn from synapse.api.constants import Membership, EventTypes
from synapse.events import EventBase
from ._base import BaseHandler from ._base import BaseHandler
@ -29,20 +30,6 @@ import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user):
return preserve_context_over_fn(
distributor.fire,
"started_user_eventstream", user
)
def stopped_user_eventstream(distributor, user):
return preserve_context_over_fn(
distributor.fire,
"stopped_user_eventstream", user
)
class EventStreamHandler(BaseHandler): class EventStreamHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -61,61 +48,6 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def started_stream(self, user):
"""Tells the presence handler that we have started an eventstream for
the user:
Args:
user (User): The user who started a stream.
Returns:
A deferred that completes once their presence has been updated.
"""
if user not in self._streams_per_user:
# Make sure we set the streams per user to 1 here rather than
# setting it to zero and incrementing the value below.
# Otherwise this may race with stopped_stream causing the
# user to be erased from the map before we have a chance
# to increment it.
self._streams_per_user[user] = 1
if user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield started_user_eventstream(self.distributor, user)
else:
self._streams_per_user[user] += 1
def stopped_stream(self, user):
"""If there are no streams for a user this starts a timer that will
notify the presence handler that we haven't got an event stream for
the user unless the user starts a new stream in 30 seconds.
Args:
user (User): The user who stopped a stream.
"""
self._streams_per_user[user] -= 1
if not self._streams_per_user[user]:
del self._streams_per_user[user]
# 30 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug("_later stopped_user_eventstream %s", user)
self._stop_timer_per_user.pop(user, None)
return stopped_user_eventstream(self.distributor, user)
logger.debug("Scheduling _later: for %s", user)
self._stop_timer_per_user[user] = (
self.clock.call_later(30, _later)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
@ -126,11 +58,12 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down. If `only_keys` is not None, events from keys will be sent down.
""" """
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_handlers().presence_handler
try: context = yield presence_handler.user_syncing(
if affect_presence: auth_user_id, affect_presence=affect_presence,
yield self.started_stream(auth_user) )
with context:
if timeout: if timeout:
# If they've set a timeout set a minimum limit. # If they've set a timeout set a minimum limit.
timeout = max(timeout, 500) timeout = max(timeout, 500)
@ -145,6 +78,34 @@ class EventStreamHandler(BaseHandler):
is_guest=is_guest, explicit_room_id=room_id is_guest=is_guest, explicit_room_id=room_id
) )
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
for event in events:
if not isinstance(event, EventBase):
continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.store.get_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
UserID.from_string(event.state_key),
as_event=True,
)
to_add.append(ev)
events.extend(to_add)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunks = [ chunks = [
@ -159,10 +120,6 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
finally:
if affect_presence:
self.stopped_stream(auth_user)
class EventHandler(BaseHandler): class EventHandler(BaseHandler):

View File

@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
"""Contains handlers for federation events.""" """Contains handlers for federation events."""
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from ._base import BaseHandler from ._base import BaseHandler
@ -99,7 +102,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, backfilled, state=None, def on_receive_pdu(self, origin, pdu, state=None,
auth_chain=None): auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
@ -120,7 +123,6 @@ class FederationHandler(BaseHandler):
# FIXME (erikj): Awful hack to make the case where we are not currently # FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work # in the room work
current_state = None
is_in_room = yield self.auth.check_host_in_room( is_in_room = yield self.auth.check_host_in_room(
event.room_id, event.room_id,
self.server_name self.server_name
@ -183,8 +185,6 @@ class FederationHandler(BaseHandler):
origin, origin,
event, event,
state=state, state=state,
backfilled=backfilled,
current_state=current_state,
) )
except AuthError as e: except AuthError as e:
raise FederationError( raise FederationError(
@ -213,18 +213,17 @@ class FederationHandler(BaseHandler):
except StoreError: except StoreError:
logger.exception("Failed to store room.") logger.exception("Failed to store room.")
if not backfilled: extra_users = []
extra_users = [] if event.type == EventTypes.Member:
if event.type == EventTypes.Member: target_user_id = event.state_key
target_user_id = event.state_key target_user = UserID.from_string(target_user_id)
target_user = UserID.from_string(target_user_id) extra_users.append(target_user)
extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
@ -469,7 +468,7 @@ class FederationHandler(BaseHandler):
limit=100, limit=100,
extremities=[e for e in extremities.keys()] extremities=[e for e in extremities.keys()]
) )
except SynapseError: except SynapseError as e:
logger.info( logger.info(
"Failed to backfill from %s because %s", "Failed to backfill from %s because %s",
dom, e, dom, e,
@ -644,7 +643,7 @@ class FederationHandler(BaseHandler):
continue continue
try: try:
self.on_receive_pdu(origin, p, backfilled=False) self.on_receive_pdu(origin, p)
except: except:
logger.exception("Couldn't handle pdu") logger.exception("Couldn't handle pdu")
@ -776,7 +775,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
backfilled=False,
) )
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
@ -810,7 +808,21 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
signed_event signed_event
) )
defer.returnValue(None)
context = yield self.state_handler.compute_event_context(event)
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
)
target_user = UserID.from_string(event.state_key)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
@ -1056,8 +1068,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_event(self, origin, event, state=None, backfilled=False, def _handle_new_event(self, origin, event, state=None, auth_events=None):
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
@ -1067,7 +1078,7 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
if not backfilled and not event.internal_metadata.is_outlier(): if not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, context, self event, context, self
@ -1076,9 +1087,7 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
backfilled=backfilled, is_new_state=not outlier,
is_new_state=(not outlier and not backfilled),
current_state=current_state,
) )
defer.returnValue((context, event_stream_id, max_stream_id)) defer.returnValue((context, event_stream_id, max_stream_id))
@ -1176,7 +1185,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
backfilled=False,
is_new_state=True, is_new_state=True,
current_state=state, current_state=state,
) )
@ -1620,19 +1628,15 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite(self, invite): def exchange_third_party_invite(
sender = invite["sender"] self,
room_id = invite["room_id"] sender_user_id,
target_user_id,
if "signed" not in invite or "token" not in invite["signed"]: room_id,
logger.info( signed,
"Discarding received notification of third party invite " ):
"without signed: %s" % (invite,)
)
return
third_party_invite = { third_party_invite = {
"signed": invite["signed"], "signed": signed,
} }
event_dict = { event_dict = {
@ -1642,8 +1646,8 @@ class FederationHandler(BaseHandler):
"third_party_invite": third_party_invite, "third_party_invite": third_party_invite,
}, },
"room_id": room_id, "room_id": room_id,
"sender": sender, "sender": sender_user_id,
"state_key": invite["mxid"], "state_key": target_user_id,
} }
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
@ -1656,11 +1660,11 @@ class FederationHandler(BaseHandler):
) )
self.auth.check(event, context.current_state) self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
yield self.replication_layer.forward_third_party_invite( yield self.replication_layer.forward_third_party_invite(
destinations, destinations,
room_id, room_id,
@ -1681,13 +1685,13 @@ class FederationHandler(BaseHandler):
) )
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event) returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context) yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context): def add_display_name_to_third_party_invite(self, event_dict, event, context):
@ -1711,17 +1715,69 @@ class FederationHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_keyserver(self, event, auth_events): def _check_signature(self, event, auth_events):
token = event.content["third_party_invite"]["signed"]["token"] """
Checks that the signature in the event is consistent with its invite.
:param event (Event): The m.room.member event to check
:param auth_events (dict<(event type, state_key), event>)
:raises
AuthError if signature didn't match any keys, or key has been
revoked,
SynapseError if a transient error meant a key couldn't be checked
for revocation.
"""
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
invite_event = auth_events.get( invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )
if not invite_event:
raise AuthError(403, "Could not find invite")
last_exception = None
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
public_key = public_key_object["public_key"]
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
if "key_validity_url" in public_key_object:
yield self._check_key_revocation(
public_key,
public_key_object["key_validity_url"]
)
return
except Exception as e:
last_exception = e
raise last_exception
@defer.inlineCallbacks
def _check_key_revocation(self, public_key, url):
"""
Checks whether public_key has been revoked.
:param public_key (str): base-64 encoded public key.
:param url (str): Key revocation URL.
:raises
AuthError if they key has been revoked.
SynapseError if a transient error meant a key couldn't be checked
for revocation.
"""
try: try:
response = yield self.hs.get_simple_http_client().get_json( response = yield self.hs.get_simple_http_client().get_json(
invite_event.content["key_validity_url"], url,
{"public_key": invite_event.content["public_key"]} {"public_key": public_key}
) )
except Exception: except Exception:
raise SynapseError( raise SynapseError(

View File

@ -16,12 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken from synapse.types import UserID, RoomStreamToken, StreamToken
@ -197,12 +196,25 @@ class MessageHandler(BaseHandler):
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership == Membership.JOIN: if membership == Membership.JOIN:
joinee = UserID.from_string(builder.state_key)
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield collect_presencelike_data( yield collect_presencelike_data(
self.distributor, joinee, builder.content self.distributor, target, builder.content
) )
elif membership == Membership.INVITE:
profile = self.hs.get_handlers().profile_handler
content = builder.content
try:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None: if token_id is not None:
builder.internal_metadata.token_id = token_id builder.internal_metadata.token_id = token_id
@ -216,7 +228,7 @@ class MessageHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def send_event(self, event, context, ratelimit=True, is_guest=False): def send_nonmember_event(self, requester, event, context, ratelimit=True):
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -226,55 +238,70 @@ class MessageHandler(BaseHandler):
ratelimit (bool): Whether to rate limit this send. ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest. is_guest (bool): Whether the sender is a guest.
""" """
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if ratelimit:
self.ratelimit(event.sender)
if event.is_state(): if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = self.deduplicate_state_event(event, context)
if prev_state and event.user_id == prev_state.user_id: if prev_state is not None:
prev_content = encode_canonical_json(prev_state.content) defer.returnValue(prev_state)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
# Duplicate suppression for state updates with same sender
# and content.
defer.returnValue(prev_state)
if event.type == EventTypes.Member: yield self.handle_new_client_event(
member_handler = self.hs.get_handlers().room_member_handler requester=requester,
yield member_handler.send_membership_event(event, context, is_guest=is_guest) event=event,
else: context=context,
yield self.handle_new_client_event( ratelimit=ratelimit,
event=event, )
context=context,
)
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
with PreserveLoggingContext(): yield presence.bump_presence_active_time(user)
presence.bump_presence_active_time(user)
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event = context.current_state.get((event.type, event.state_key))
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
return None
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_nonmember_event(
token_id=None, txn_id=None, is_guest=False): self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
""" """
Creates an event, then sends it. Creates an event, then sends it.
See self.create_event and self.send_event. See self.create_event and self.send_nonmember_event.
""" """
event, context = yield self.create_event( event, context = yield self.create_event(
event_dict, event_dict,
token_id=token_id, token_id=requester.access_token_id,
txn_id=txn_id txn_id=txn_id
) )
yield self.send_event( yield self.send_nonmember_event(
requester,
event, event,
context, context,
ratelimit=ratelimit, ratelimit=ratelimit,
is_guest=is_guest
) )
defer.returnValue(event) defer.returnValue(event)
@ -635,8 +662,8 @@ class MessageHandler(BaseHandler):
user_id, messages, is_peeking=is_peeking user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken(token[1], 0, 0, 0, 0) end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -660,10 +687,6 @@ class MessageHandler(BaseHandler):
room_id=room_id, room_id=room_id,
) )
# TODO(paul): I wish I was called with user objects not user_id
# strings...
auth_user = UserID.from_string(user_id)
# TODO: These concurrently # TODO: These concurrently
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
state = [ state = [
@ -688,13 +711,11 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members], [m.user_id for m in room_members],
auth_user=auth_user,
as_event=True, as_event=True,
check_auth=False,
) )
defer.returnValue(states.values()) defer.returnValue(states)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_receipts(): def get_receipts():

File diff suppressed because it is too large Load Diff

View File

@ -16,8 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, Requester
from synapse.types import UserID
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from ._base import BaseHandler from ._base import BaseHandler
@ -49,6 +48,9 @@ class ProfileHandler(BaseHandler):
distributor = hs.get_distributor() distributor = hs.get_distributor()
self.distributor = distributor self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user) distributor.observe("registered_user", self.registered_user)
distributor.observe( distributor.observe(
@ -87,13 +89,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["displayname"]) defer.returnValue(result["displayname"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, auth_user, new_displayname): def set_displayname(self, target_user, requester, new_displayname):
"""target_user is the user whose displayname is to be changed; """target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '': if new_displayname == '':
@ -107,7 +109,7 @@ class ProfileHandler(BaseHandler):
"displayname": new_displayname, "displayname": new_displayname,
}) })
yield self._update_join_states(target_user) yield self._update_join_states(requester)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_avatar_url(self, target_user): def get_avatar_url(self, target_user):
@ -137,13 +139,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["avatar_url"]) defer.returnValue(result["avatar_url"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_avatar_url(self, target_user, auth_user, new_avatar_url): def set_avatar_url(self, target_user, requester, new_avatar_url):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(
@ -154,7 +156,7 @@ class ProfileHandler(BaseHandler):
"avatar_url": new_avatar_url, "avatar_url": new_avatar_url,
}) })
yield self._update_join_states(target_user) yield self._update_join_states(requester)
@defer.inlineCallbacks @defer.inlineCallbacks
def collect_presencelike_data(self, user, state): def collect_presencelike_data(self, user, state):
@ -197,32 +199,30 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_join_states(self, user): def _update_join_states(self, requester):
user = requester.user
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
return return
self.ratelimit(user.to_string()) self.ratelimit(requester)
joins = yield self.store.get_rooms_for_user( joins = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),
) )
for j in joins: for j in joins:
content = { handler = self.hs.get_handlers().room_member_handler
"membership": Membership.JOIN,
}
yield collect_presencelike_data(self.distributor, user, content)
msg_handler = self.hs.get_handlers().message_handler
try: try:
yield msg_handler.create_and_send_event({ # Assume the user isn't a guest because we don't let guests set
"type": EventTypes.Member, # profile or avatar data.
"room_id": j.room_id, requester = Requester(user, "", False)
"state_key": user.to_string(), yield handler.update_membership(
"content": content, requester,
"sender": user.to_string() user,
}, ratelimit=False) j.room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to update join event for room %s - %s", "Failed to update join event for room %s - %s",

View File

@ -36,8 +36,6 @@ class ReceiptsHandler(BaseHandler):
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self._receipt_cache = None
@defer.inlineCallbacks @defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id, def received_client_receipt(self, room_id, receipt_type, user_id,
event_id): event_id):

View File

@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id = None self._next_generated_user_id = None
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None): def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
yield run_on_reactor() yield run_on_reactor()
if urllib.quote(localpart.encode('utf-8')) != localpart: if urllib.quote(localpart.encode('utf-8')) != localpart:
@ -60,7 +61,16 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) if assigned_user_id:
if user_id == assigned_user_id:
return
else:
raise SynapseError(
400,
"A different user ID has already been registered for this session",
)
yield self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
@ -145,7 +155,7 @@ class RegistrationHandler(BaseHandler):
localpart = yield self._generate_user_id(attempts > 0) localpart = yield self._generate_user_id(attempts > 0)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token: if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
try: try:
@ -157,6 +167,7 @@ class RegistrationHandler(BaseHandler):
) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user = None
user_id = None user_id = None
token = None token = None
attempts += 1 attempts += 1
@ -180,11 +191,19 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.", 400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
service_id = service.id if service.is_exclusive_user(user_id) else None
yield self.check_user_id_not_appservice_exclusive(
user_id, allowed_appservice=service
)
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash="" password_hash="",
appservice_id=service_id,
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@ -226,7 +245,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
@ -235,7 +254,7 @@ class RegistrationHandler(BaseHandler):
password_hash=None password_hash=None
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
except Exception, e: except Exception as e:
yield self.store.add_access_token_to_user(user_id, token) yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors # Ignore Registration errors
logger.exception(e) logger.exception(e)
@ -278,12 +297,14 @@ class RegistrationHandler(BaseHandler):
yield identity_handler.bind_threepid(c, user_id) yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_id_is_valid(self, user_id): def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
# valid user IDs must not clash with any user ID namespaces claimed by # valid user IDs must not clash with any user ID namespaces claimed by
# application services. # application services.
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_user(user_id) s for s in services
if s.is_interested_in_user(user_id)
and s != allowed_appservice
] ]
for service in interested_services: for service in interested_services:
if service.is_exclusive_user(user_id): if service.is_exclusive_user(user_id):
@ -342,3 +363,18 @@ class RegistrationHandler(BaseHandler):
def auth_handler(self): def auth_handler(self):
return self.hs.get_handlers().auth_handler return self.hs.get_handlers().auth_handler
@defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id):
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
defer.returnValue(access_token)
_, access_token = yield self.register(
generate_token=True,
make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
)
defer.returnValue(access_token)

File diff suppressed because it is too large Load Diff

View File

@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.push.clientformat import format_push_rules_for_user
from twisted.internet import defer from twisted.internet import defer
@ -121,7 +122,11 @@ class SyncResult(collections.namedtuple("SyncResult", [
events. events.
""" """
return bool( return bool(
self.presence or self.joined or self.invited or self.archived self.presence or
self.joined or
self.invited or
self.archived or
self.account_data
) )
@ -205,9 +210,9 @@ class SyncHandler(BaseHandler):
key=None key=None
) )
membership_list = (Membership.INVITE, Membership.JOIN) membership_list = (
if sync_config.filter_collection.include_leave: Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
membership_list += (Membership.LEAVE, Membership.BAN) )
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
@ -220,6 +225,10 @@ class SyncHandler(BaseHandler):
) )
) )
account_data['m.push_rules'] = yield self.push_rules_for_user(
sync_config.user
)
tags_by_room = yield self.store.get_tags_for_user( tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -253,6 +262,12 @@ class SyncHandler(BaseHandler):
invite=invite, invite=invite,
)) ))
elif event.membership in (Membership.LEAVE, Membership.BAN): elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if sync_config.user.to_string() == event.sender:
continue
leave_token = now_token.copy_and_replace( leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,) "room_key", "s%d" % (event.stream_ordering,)
) )
@ -318,6 +333,14 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks
def push_rules_for_user(self, user):
user_id = user.to_string()
rawrules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
rules = format_push_rules_for_user(user, rawrules, enabled_map)
defer.returnValue(rules)
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
account_data_events = [] account_data_events = []
@ -477,6 +500,15 @@ class SyncHandler(BaseHandler):
) )
) )
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
user_id, int(since_token.push_rules_key)
)
if push_rules_changed:
account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
)
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
@ -582,6 +614,28 @@ class SyncHandler(BaseHandler):
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
# For each newly joined room, we want to send down presence of
# existing users.
presence_handler = self.hs.get_handlers().presence_handler
extra_presence_users = set()
for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(event.room_id)
extra_presence_users.update(users)
# For each new member, send down presence.
for joined_sync in joined:
it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
extra_presence_users.add(event.state_key)
states = yield presence_handler.get_states(
[u for u in extra_presence_users if u != user_id],
as_event=True,
)
presence.extend(states)
account_data_for_user = sync_config.filter_collection.filter_account_data( account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data) self.account_data_for_user(account_data)
) )
@ -623,7 +677,6 @@ class SyncHandler(BaseHandler):
recents = yield self._filter_events_for_client( recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
recents, recents,
is_peeking=sync_config.is_guest,
) )
else: else:
recents = [] recents = []
@ -645,7 +698,6 @@ class SyncHandler(BaseHandler):
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
is_peeking=sync_config.is_guest,
) )
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents
@ -825,14 +877,20 @@ class SyncHandler(BaseHandler):
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
if full_state: if full_state:
if batch: if batch:
current_state = yield self.store.get_state_for_event(
batch.events[-1].event_id
)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
else: else:
state = yield self.get_state_at( current_state = yield self.get_state_at(
room_id, stream_position=now_token room_id, stream_position=now_token
) )
state = current_state
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
@ -842,12 +900,17 @@ class SyncHandler(BaseHandler):
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state, timeline_start=state,
previous={}, previous={},
current=current_state,
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token room_id, stream_position=since_token
) )
current_state = yield self.store.get_state_for_event(
batch.events[-1].event_id
)
state_at_timeline_start = yield self.store.get_state_for_event( state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
@ -861,6 +924,7 @@ class SyncHandler(BaseHandler):
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
previous=state_at_previous_sync, previous=state_at_previous_sync,
current=current_state,
) )
else: else:
state = {} state = {}
@ -920,7 +984,7 @@ def _action_has_highlight(actions):
return False return False
def _calculate_state(timeline_contains, timeline_start, previous): def _calculate_state(timeline_contains, timeline_start, previous, current):
"""Works out what state to include in a sync response. """Works out what state to include in a sync response.
Args: Args:
@ -928,6 +992,7 @@ def _calculate_state(timeline_contains, timeline_start, previous):
timeline_start (dict): state at the start of the timeline timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync) if this is an initial sync)
current (dict): state at the end of the timeline
Returns: Returns:
dict dict
@ -938,14 +1003,16 @@ def _calculate_state(timeline_contains, timeline_start, previous):
timeline_contains.values(), timeline_contains.values(),
previous.values(), previous.values(),
timeline_start.values(), timeline_start.values(),
current.values(),
) )
} }
c_ids = set(e.event_id for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values()) tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values()) p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values()) ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids) evs = (event_id_to_state[e] for e in state_ids)
return { return {

View File

@ -25,6 +25,7 @@ from synapse.types import UserID
import logging import logging
from collections import namedtuple from collections import namedtuple
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -219,6 +220,19 @@ class TypingNotificationHandler(BaseHandler):
"typing_key", self._latest_room_serial, rooms=[room_id] "typing_key", self._latest_room_serial, rooms=[room_id]
) )
def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state.
rows = []
for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id]
typing_bytes = json.dumps([
u.to_string() for u in typing
], ensure_ascii=False)
rows.append((serial, room_id, typing_bytes))
rows.sort()
return rows
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):
def __init__(self, hs): def __init__(self, hs):

View File

@ -103,7 +103,7 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents? # TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -249,7 +249,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_raw(self, url, args={}): def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -269,6 +269,19 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response) defer.returnValue(e.response)
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_urlencode_arg(arg):
if isinstance(arg, unicode):
return arg.encode('utf-8')
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
return arg
def _print_ex(e): def _print_ex(e):
if hasattr(e, "reasons") and e.reasons: if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons: for ex in e.reasons:

View File

@ -18,6 +18,7 @@ from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
) )
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict
import synapse.metrics import synapse.metrics
import synapse.events import synapse.events
@ -229,11 +230,12 @@ class JsonResource(HttpServer, resource.Resource):
else: else:
servlet_classname = "%r" % callback servlet_classname = "%r" % callback
args = [ kwargs = intern_dict({
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups() name: urllib.unquote(value).decode("UTF-8") if value else value
] for name, value in m.groupdict().items()
})
callback_return = yield callback(request, *args) callback_return = yield callback(request, **kwargs)
if callback_return is not None: if callback_return is not None:
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
@ -367,10 +369,29 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
"Origin, X-Requested-With, Content-Type, Accept") "Origin, X-Requested-With, Content-Type, Accept")
request.write(json_bytes) request.write(json_bytes)
request.finish() finish_request(request)
return NOT_DONE_YET return NOT_DONE_YET
def finish_request(request):
""" Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
response was written but doesn't provide a convenient or reliable way to
determine if the connection was closed. So we catch and log the RuntimeException
You might think that ``request.notifyFinish`` could be used to tell if the
request was finished. However the deferred it returns won't fire if the
connection was already closed, meaning we'd have to have called the method
right at the start of the request. By the time we want to write the response
it will already be too late.
"""
try:
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
def _request_user_agent_is_curl(request): def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders( user_agents = request.requestHeaders.getRawHeaders(
"User-Agent", default=[] "User-Agent", default=[]

View File

@ -15,14 +15,27 @@
""" This module contains base REST classes for constructing REST servlets. """ """ This module contains base REST classes for constructing REST servlets. """
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, Codes
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False): def parse_integer(request, name, default=None, required=False):
"""Parse an integer parameter from the request string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: An int value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present and not an integer.
"""
if name in request.args: if name in request.args:
try: try:
return int(request.args[name][0]) return int(request.args[name][0])
@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False):
else: else:
if required: if required:
message = "Missing integer query parameter %r" % (name,) message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_boolean(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False):
"""Parse a boolean parameter from the request query string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: A bool value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false".
"""
if name in request.args: if name in request.args:
try: try:
return { return {
@ -53,30 +79,84 @@ def parse_boolean(request, name, default=None, required=False):
else: else:
if required: if required:
message = "Missing boolean query parameter %r" % (name,) message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_string(request, name, default=None, required=False, def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string.
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:param allowed_values (list): List of allowed values for the string,
or None if any value is allowed, defaults to None
:return: A string value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
if name in request.args: if name in request.args:
value = request.args[name][0] value = request.args[name][0]
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)
) )
raise SynapseError(message) raise SynapseError(400, message)
else: else:
return value return value
else: else:
if required: if required:
message = "Missing %s query parameter %r" % (param_type, name) message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_json_value_from_request(request):
"""Parse a JSON value from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:returns: The JSON value.
:raises
SynapseError if the request body couldn't be decoded as JSON.
"""
try:
content_bytes = request.content.read()
except:
raise SynapseError(400, "Error reading JSON content.")
try:
content = simplejson.loads(content_bytes)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
def parse_json_object_from_request(request):
"""Parse a JSON object from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:raises
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
content = parse_json_value_from_request(request)
if type(content) != dict:
message = "Content must be a JSON object."
raise SynapseError(400, message, errcode=Codes.BAD_JSON)
return content
class RestServlet(object): class RestServlet(object):
""" A Synapse REST Servlet. """ A Synapse REST Servlet.

View File

@ -159,6 +159,8 @@ class Notifier(object):
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
) )
self.replication_deferred = ObservableDeferred(defer.Deferred())
# This is not a very cheap test to perform, but it's only executed # This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at # when rendering the metrics page, which is likely once per minute at
# most when scraping it. # most when scraping it.
@ -207,6 +209,8 @@ class Notifier(object):
)) ))
self._notify_pending_new_room_events(max_room_stream_id) self._notify_pending_new_room_events(max_room_stream_id)
self.notify_replication()
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
event to be persisted. event to be persisted.
@ -276,9 +280,17 @@ class Notifier(object):
except: except:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
self.notify_replication()
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
with PreserveLoggingContext():
self.notify_replication()
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken.START):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
@ -479,3 +491,45 @@ class Notifier(object):
room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream) room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id) new_user_stream.rooms.add(room_id)
def notify_replication(self):
"""Notify the any replication listeners that there's a new event"""
with PreserveLoggingContext():
deferred = self.replication_deferred
self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None)
@defer.inlineCallbacks
def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen.
:param callback:
Gets called whenever an event happens. If this returns a truthy
value then ``wait_for_replication`` returns, otherwise it waits
for another event.
:param int timeout:
How many milliseconds to wait for callback return a truthy value.
:returns:
A deferred that resolves with the value returned by the callback.
"""
listener = _NotificationListener(None)
def timed_out():
listener.deferred.cancel()
timer = self.clock.call_later(timeout / 1000., timed_out)
while True:
listener.deferred = self.replication_deferred.observe()
result = yield callback()
if result:
break
try:
with PreserveLoggingContext():
yield listener.deferred
except defer.CancelledError:
break
self.clock.cancel_call_later(timer, ignore_errs=True)
defer.returnValue(result)

View File

@ -21,7 +21,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
import synapse.util.async import synapse.util.async
import push_rule_evaluator as push_rule_evaluator from .push_rule_evaluator import evaluator_for_user_id
import logging import logging
import random import random
@ -47,14 +47,13 @@ class Pusher(object):
MAX_BACKOFF = 60 * 60 * 1000 MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000
def __init__(self, _hs, profile_tag, user_id, app_id, def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
self.hs = _hs self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.profile_tag = profile_tag
self.user_id = user_id self.user_id = user_id
self.app_id = app_id self.app_id = app_id
self.app_display_name = app_display_name self.app_display_name = app_display_name
@ -186,8 +185,8 @@ class Pusher(object):
processed = False processed = False
rule_evaluator = yield \ rule_evaluator = yield \
push_rule_evaluator.evaluator_for_user_id_and_profile_tag( evaluator_for_user_id(
self.user_id, self.profile_tag, single_event['room_id'], self.store self.user_id, single_event['room_id'], self.store
) )
actions = yield rule_evaluator.actions_for_event(single_event) actions = yield rule_evaluator.actions_for_event(single_event)
@ -318,7 +317,7 @@ class Pusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_badge_count(self): def _get_badge_count(self):
invites, joins = yield defer.gatherResults([ invites, joins = yield defer.gatherResults([
self.store.get_invites_for_user(self.user_id), self.store.get_invited_rooms_for_user(self.user_id),
self.store.get_rooms_for_user(self.user_id), self.store.get_rooms_for_user(self.user_id),
], consumeErrors=True) ], consumeErrors=True)

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import bulk_push_rule_evaluator from .bulk_push_rule_evaluator import evaluator_for_room_id
import logging import logging
@ -35,7 +35,7 @@ class ActionGenerator:
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context, handler): def handle_push_actions_for_event(self, event, context, handler):
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id( bulk_evaluator = yield evaluator_for_room_id(
event.room_id, self.hs, self.store event.room_id, self.hs, self.store
) )
@ -44,5 +44,5 @@ class ActionGenerator:
) )
context.push_actions = [ context.push_actions = [
(uid, None, actions) for uid, actions in actions_by_user.items() (uid, actions) for uid, actions in actions_by_user.items()
] ]

View File

@ -13,46 +13,67 @@
# limitations under the License. # limitations under the License.
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
import copy
def list_with_base_rules(rawrules): def list_with_base_rules(rawrules):
"""Combine the list of rules set by the user with the default push rules
:param list rawrules: The rules the user has modified or set.
:returns: A new list with the rules set by the user combined with the
defaults.
"""
ruleslist = [] ruleslist = []
# Grab the base rules that the user has modified.
# The modified base rules have a priority_class of -1.
modified_base_rules = {
r['rule_id']: r for r in rawrules if r['priority_class'] < 0
}
# Remove the modified base rules from the list, They'll be added back
# in the default postions in the list.
rawrules = [r for r in rawrules if r['priority_class'] >= 0]
# shove the server default rules for each kind onto the end of each # shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
)) ))
for r in rawrules: for r in rawrules:
if r['priority_class'] < current_prio_class: if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
ruleslist.append(r) ruleslist.append(r)
while current_prio_class > 0: while current_prio_class > 0:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
)) ))
return ruleslist return ruleslist
def make_base_append_rules(kind): def make_base_append_rules(kind, modified_base_rules):
rules = [] rules = []
if kind == 'override': if kind == 'override':
@ -62,15 +83,31 @@ def make_base_append_rules(kind):
elif kind == 'content': elif kind == 'content':
rules = BASE_APPEND_CONTENT_RULES rules = BASE_APPEND_CONTENT_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
modified = modified_base_rules.get(r['rule_id'])
if modified:
r['actions'] = modified['actions']
return rules return rules
def make_base_prepend_rules(kind): def make_base_prepend_rules(kind, modified_base_rules):
rules = [] rules = []
if kind == 'override': if kind == 'override':
rules = BASE_PREPEND_OVERRIDE_RULES rules = BASE_PREPEND_OVERRIDE_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
modified = modified_base_rules.get(r['rule_id'])
if modified:
r['actions'] = modified['actions']
return rules return rules
@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [
] ]
BASE_RULE_IDS = set()
for r in BASE_APPEND_CONTENT_RULES: for r in BASE_APPEND_CONTENT_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['content'] r['priority_class'] = PRIORITY_CLASS_MAP['content']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_PREPEND_OVERRIDE_RULES: for r in BASE_PREPEND_OVERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_APPEND_OVRRIDE_RULES: for r in BASE_APPEND_OVRRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_APPEND_UNDERRIDE_RULES: for r in BASE_APPEND_UNDERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['underride'] r['priority_class'] = PRIORITY_CLASS_MAP['underride']
r['default'] = True r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])

View File

@ -18,8 +18,8 @@ import ujson as json
from twisted.internet import defer from twisted.internet import defer
import baserules from .baserules import list_with_base_rules
from push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -39,7 +39,7 @@ def _get_rules(room_id, user_ids, store):
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = { rules_by_user = {
uid: baserules.list_with_base_rules([ uid: list_with_base_rules([
decode_rule_json(rule_list) decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, []) for rule_list in rules_by_user.get(uid, [])
]) ])
@ -103,11 +103,13 @@ class BulkPushRuleEvaluator:
users_dict = yield self.store.are_guests(self.rules_by_user.keys()) users_dict = yield self.store.are_guests(self.rules_by_user.keys())
filtered_by_user = yield handler._filter_events_for_clients( filtered_by_user = yield handler.filter_events_for_clients(
users_dict.items(), [event], {event.event_id: current_state} users_dict.items(), [event], {event.event_id: current_state}
) )
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room)) room_members = yield self.store.get_users_in_room(self.room_id)
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
condition_cache = {} condition_cache = {}
@ -152,7 +154,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
elif res is True: elif res is True:
continue continue
res = evaluator.matches(cond, uid, display_name, None) res = evaluator.matches(cond, uid, display_name)
if _id: if _id:
cache[_id] = bool(res) cache[_id] = bool(res)

View File

@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push.baserules import list_with_base_rules
from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
import copy
import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
return rules
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
def _priority_class_to_template_name(pc):
return PRIORITY_CLASS_INVERSE_MAP[pc]

View File

@ -23,12 +23,11 @@ logger = logging.getLogger(__name__)
class HttpPusher(Pusher): class HttpPusher(Pusher):
def __init__(self, _hs, profile_tag, user_id, app_id, def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__( super(HttpPusher, self).__init__(
_hs, _hs,
profile_tag,
user_id, user_id,
app_id, app_id,
app_display_name, app_display_name,

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import baserules from .baserules import list_with_base_rules
import logging import logging
import simplejson as json import simplejson as json
@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): def evaluator_for_user_id(user_id, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id) rawrules = yield store.get_push_rules_for_user(user_id)
enabled_map = yield store.get_push_rules_enabled_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state( our_member_event = yield store.get_current_state(
@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
) )
defer.returnValue(PushRuleEvaluator( defer.returnValue(PushRuleEvaluator(
user_id, profile_tag, rawrules, enabled_map, user_id, rawrules, enabled_map,
room_id, our_member_event, store room_id, our_member_event, store
)) ))
@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count):
class PushRuleEvaluator: class PushRuleEvaluator:
DEFAULT_ACTIONS = [] DEFAULT_ACTIONS = []
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, def __init__(self, user_id, raw_rules, enabled_map, room_id,
our_member_event, store): our_member_event, store):
self.user_id = user_id self.user_id = user_id
self.profile_tag = profile_tag
self.room_id = room_id self.room_id = room_id
self.our_member_event = our_member_event self.our_member_event = our_member_event
self.store = store self.store = store
@ -92,7 +91,7 @@ class PushRuleEvaluator:
rule['actions'] = json.loads(raw_rule['actions']) rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule) rules.append(rule)
self.rules = baserules.list_with_base_rules(rules) self.rules = list_with_base_rules(rules)
self.enabled_map = enabled_map self.enabled_map = enabled_map
@ -152,7 +151,7 @@ class PushRuleEvaluator:
matches = True matches = True
for c in conditions: for c in conditions:
matches = evaluator.matches( matches = evaluator.matches(
c, self.user_id, my_display_name, self.profile_tag c, self.user_id, my_display_name
) )
if not matches: if not matches:
break break
@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object):
# Maps strings of e.g. 'content.body' -> event["content"]["body"] # Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
def matches(self, condition, user_id, display_name, profile_tag): def matches(self, condition, user_id, display_name):
if condition['kind'] == 'event_match': if condition['kind'] == 'event_match':
return self._event_match(condition, user_id) return self._event_match(condition, user_id)
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name': elif condition['kind'] == 'contains_display_name':
return self._contains_display_name(display_name) return self._contains_display_name(display_name)
elif condition['kind'] == 'room_member_count': elif condition['kind'] == 'room_member_count':

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from httppusher import HttpPusher from .httppusher import HttpPusher
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
@ -29,6 +29,7 @@ class PusherPool:
def __init__(self, _hs): def __init__(self, _hs):
self.hs = _hs self.hs = _hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pushers = {} self.pushers = {}
self.last_pusher_started = -1 self.last_pusher_started = -1
@ -38,8 +39,11 @@ class PusherPool:
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data,
profile_tag=""):
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
@ -47,23 +51,31 @@ class PusherPool:
self._create_pusher({ self._create_pusher({
"user_name": user_id, "user_name": user_id,
"kind": kind, "kind": kind,
"profile_tag": profile_tag,
"app_id": app_id, "app_id": app_id,
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
"pushkey": pushkey, "pushkey": pushkey,
"ts": self.hs.get_clock().time_msec(), "ts": time_now_msec,
"lang": lang, "lang": lang,
"data": data, "data": data,
"last_token": None, "last_token": None,
"last_success": None, "last_success": None,
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self.store.add_pusher(
user_id, access_token, profile_tag, kind, app_id, user_id=user_id,
app_display_name, device_display_name, access_token=access_token,
pushkey, lang, data kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=time_now_msec,
lang=lang,
data=data,
profile_tag=profile_tag,
) )
yield self._refresh_pusher(app_id, pushkey, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
@ -80,44 +92,24 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id): def remove_pushers_by_user(self, user_id, except_token_ids=[]):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s", "Removing all pushers for user %s except access tokens ids %r",
user_id, user_id, except_token_ids
) )
for p in all: for p in all:
if p['user_name'] == user_id: if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']
) )
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_id=user_id,
access_token=access_token,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
data=data,
)
yield self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
return HttpPusher( return HttpPusher(
self.hs, self.hs,
profile_tag=pusherdict['profile_tag'],
user_id=pusherdict['user_name'], user_id=pusherdict['user_name'],
app_id=pusherdict['app_id'], app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],

View File

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"], "frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
@ -34,7 +34,7 @@ REQUIREMENTS = {
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
"blist": ["blist"], "blist": ["blist"],
"pysaml2": ["saml2"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,367 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
import ujson as json
import collections
import logging
logger = logging.getLogger(__name__)
REPLICATION_PREFIX = "/_synapse/replication"
STREAM_NAMES = (
("events",),
("presence",),
("typing",),
("receipts",),
("user_account_data", "room_account_data", "tag_account_data",),
("backfill",),
("push_rules",),
("pushers",),
)
class ReplicationResource(Resource):
"""
HTTP endpoint for extracting data from synapse.
The streams of data returned by the endpoint are controlled by the
parameters given to the API. To return a given stream pass a query
parameter with a position in the stream to return data from or the
special value "-1" to return data from the start of the stream.
If there is no data for any of the supplied streams after the given
position then the request will block until there is data for one
of the streams. This allows clients to long-poll this API.
The possible streams are:
* "streams": A special stream returing the positions of other streams.
* "events": The new events seen on the server.
* "presence": Presence updates.
* "typing": Typing updates.
* "receipts": Receipt updates.
* "user_account_data": Top-level per user account data.
* "room_account_data: Per room per user account data.
* "tag_account_data": Per room per user tags.
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers.
The API takes two additional query parameters:
* "timeout": How long to wait before returning an empty response.
* "limit": The maximum number of rows to return for the selected streams.
The response is a JSON object with keys for each stream with updates. Under
each key is a JSON object with:
* "postion": The current position of the stream.
* "field_names": The names of the fields in each row.
* "rows": The updates as an array of arrays.
There are a number of ways this API could be used:
1) To replicate the contents of the backing database to another database.
2) To be notified when the contents of a shared backing database changes.
3) To "tail" the activity happening on a server for debugging.
In the first case the client would track all of the streams and store it's
own copy of the data.
In the second case the client might theoretically just be able to follow
the "streams" stream to track where the other streams are. However in
practise it will probably need to get the contents of the streams in
order to expire the any in-memory caches. Whether it gets the contents
of the streams from this replication API or directly from the backing
store is a matter of taste.
In the third case the client would use the "streams" stream to find what
streams are available and their current positions. Then it can start
long-polling this replication API for new data on those streams.
"""
isLeaf = True
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.store = hs.get_datastore()
self.sources = hs.get_event_sources()
self.presence_handler = hs.get_handlers().presence_handler
self.typing_handler = hs.get_handlers().typing_notification_handler
self.notifier = hs.notifier
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@defer.inlineCallbacks
def current_replication_token(self):
stream_token = yield self.sources.get_current_token()
backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
int(stream_token.presence_key),
int(stream_token.typing_key),
int(stream_token.receipt_key),
int(stream_token.account_data_key),
backfill_token,
push_rules_token,
pushers_token,
))
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
limit = parse_integer(request, "limit", 100)
timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json")
writer = _Writer(request)
@defer.inlineCallbacks
def replicate():
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit)
yield self.events(writer, current_token, limit)
yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
yield self.pushers(writer, current_token, limit)
self.streams(writer, current_token)
logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.total)
yield self.notifier.wait_for_replication(replicate, timeout)
writer.finish()
def streams(self, writer, current_token):
request_token = parse_string(writer.request, "streams")
streams = []
if request_token is not None:
if request_token == "-1":
for names, position in zip(STREAM_NAMES, current_token):
streams.extend((name, position) for name in names)
else:
items = zip(
STREAM_NAMES,
current_token,
_ReplicationToken(request_token)
)
for names, current_id, last_id in items:
if last_id < current_id:
streams.extend((name, current_id) for name in names)
if streams:
writer.write_header_and_rows(
"streams", streams, ("name", "position"),
position=str(current_token)
)
@defer.inlineCallbacks
def events(self, writer, current_token, limit):
request_events = parse_integer(writer.request, "events")
request_backfill = parse_integer(writer.request, "backfill")
if request_events is not None or request_backfill is not None:
if request_events is None:
request_events = current_token.events
if request_backfill is None:
request_backfill = current_token.backfill
events_rows, backfill_rows = yield self.store.get_all_new_events(
request_backfill, request_events,
current_token.backfill, current_token.events,
limit
)
writer.write_header_and_rows(
"events", events_rows, ("position", "internal", "json")
)
writer.write_header_and_rows(
"backfill", backfill_rows, ("position", "internal", "json")
)
@defer.inlineCallbacks
def presence(self, writer, current_token):
current_position = current_token.presence
request_presence = parse_integer(writer.request, "presence")
if request_presence is not None:
presence_rows = yield self.presence_handler.get_all_presence_updates(
request_presence, current_position
)
writer.write_header_and_rows("presence", presence_rows, (
"position", "user_id", "state", "last_active_ts",
"last_federation_update_ts", "last_user_sync_ts",
"status_msg", "currently_active",
))
@defer.inlineCallbacks
def typing(self, writer, current_token):
current_position = current_token.presence
request_typing = parse_integer(writer.request, "typing")
if request_typing is not None:
typing_rows = yield self.typing_handler.get_all_typing_updates(
request_typing, current_position
)
writer.write_header_and_rows("typing", typing_rows, (
"position", "room_id", "typing"
))
@defer.inlineCallbacks
def receipts(self, writer, current_token, limit):
current_position = current_token.receipts
request_receipts = parse_integer(writer.request, "receipts")
if request_receipts is not None:
receipts_rows = yield self.store.get_all_updated_receipts(
request_receipts, current_position, limit
)
writer.write_header_and_rows("receipts", receipts_rows, (
"position", "room_id", "receipt_type", "user_id", "event_id", "data"
))
@defer.inlineCallbacks
def account_data(self, writer, current_token, limit):
current_position = current_token.account_data
user_account_data = parse_integer(writer.request, "user_account_data")
room_account_data = parse_integer(writer.request, "room_account_data")
tag_account_data = parse_integer(writer.request, "tag_account_data")
if user_account_data is not None or room_account_data is not None:
if user_account_data is None:
user_account_data = current_position
if room_account_data is None:
room_account_data = current_position
user_rows, room_rows = yield self.store.get_all_updated_account_data(
user_account_data, room_account_data, current_position, limit
)
writer.write_header_and_rows("user_account_data", user_rows, (
"position", "user_id", "type", "content"
))
writer.write_header_and_rows("room_account_data", room_rows, (
"position", "user_id", "room_id", "type", "content"
))
if tag_account_data is not None:
tag_rows = yield self.store.get_all_updated_tags(
tag_account_data, current_position, limit
)
writer.write_header_and_rows("tag_account_data", tag_rows, (
"position", "user_id", "room_id", "tags"
))
@defer.inlineCallbacks
def push_rules(self, writer, current_token, limit):
current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
push_rules, current_position, limit
)
writer.write_header_and_rows("push_rules", rows, (
"position", "event_stream_ordering", "user_id", "rule_id", "op",
"priority_class", "priority", "conditions", "actions"
))
@defer.inlineCallbacks
def pushers(self, writer, current_token, limit):
current_position = current_token.pushers
pushers = parse_integer(writer.request, "pushers")
if pushers is not None:
updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit
)
writer.write_header_and_rows("pushers", updated, (
"position", "user_id", "access_token", "profile_tag", "kind",
"app_id", "app_display_name", "device_display_name", "pushkey",
"ts", "lang", "data"
))
writer.write_header_and_rows("deleted", deleted, (
"position", "user_id", "app_id", "pushkey"
))
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
def __init__(self, request):
self.streams = {}
self.request = request
self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None):
if not rows:
return
if position is None:
position = rows[-1][0]
self.streams[name] = {
"position": str(position),
"field_names": fields,
"rows": rows,
}
self.total += len(rows)
def finish(self):
self.request.write(json.dumps(self.streams, ensure_ascii=False))
finish_request(self.request)
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers"
))):
__slots__ = []
def __new__(cls, *args):
if len(args) == 1:
streams = [int(value) for value in args[0].split("_")]
if len(streams) < len(cls._fields):
streams.extend([0] * (len(cls._fields) - len(streams)))
return cls(*streams)
else:
return super(_ReplicationToken, cls).__new__(cls, *args)
def __str__(self):
return "_".join(str(value) for value in self)

View File

@ -30,6 +30,7 @@ from synapse.rest.client.v1 import (
push_rule, push_rule,
register as v1_register, register as v1_register,
login as v1_login, login as v1_login,
logout,
) )
from synapse.rest.client.v2_alpha import ( from synapse.rest.client.v2_alpha import (
@ -72,6 +73,7 @@ class ClientRestResource(JsonResource):
admin.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource)
logout.register_servlets(hs, client_resource)
# "v2" # "v2"
sync.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import logging import logging

View File

@ -18,9 +18,10 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import RoomAlias from synapse.types import RoomAlias
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
@ -29,6 +30,7 @@ logger = logging.getLogger(__name__)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server) ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
@ -45,7 +47,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
content = _parse_json(request) content = parse_json_object_from_request(request)
if "room_id" not in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, "Missing room_id key",
errcode=Codes.BAD_JSON) errcode=Codes.BAD_JSON)
@ -75,7 +77,11 @@ class ClientDirectoryServer(ClientV1RestServlet):
yield dir_handler.create_association( yield dir_handler.create_association(
user_id, room_alias, room_id, servers user_id, room_alias, room_id, servers
) )
yield dir_handler.send_room_alias_update_event(user_id, room_id) yield dir_handler.send_room_alias_update_event(
requester,
user_id,
room_id
)
except SynapseError as e: except SynapseError as e:
raise e raise e
except: except:
@ -118,15 +124,13 @@ class ClientDirectoryServer(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
raise AuthError(403, "You need to be a server admin")
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_association( yield dir_handler.delete_association(
user.to_string(), room_alias requester, user.to_string(), room_alias
) )
logger.info( logger.info(
"User %s deleted alias %s", "User %s deleted alias %s",
user.to_string(), user.to_string(),
@ -136,12 +140,42 @@ class ClientDirectoryServer(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _parse_json(request): class ClientDirectoryListServer(ClientV1RestServlet):
try: PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$")
content = json.loads(request.content.read())
if type(content) != dict: def __init__(self, hs):
raise SynapseError(400, "Content must be a JSON object.", super(ClientDirectoryListServer, self).__init__(hs)
errcode=Codes.NOT_JSON) self.store = hs.get_datastore()
return content
except ValueError: @defer.inlineCallbacks
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) def on_GET(self, request, room_id):
room = yield self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Unknown room")
defer.returnValue((200, {
"visibility": "public" if room["is_public"] else "private"
}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
yield self.handlers.directory_handler.edit_published_room_list(
requester, room_id, visibility,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
yield self.handlers.directory_handler.edit_published_room_list(
requester, room_id, "private",
)
defer.returnValue((200, {}))

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing # TODO: Needs unit testing

View File

@ -17,7 +17,10 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, LoginError, Codes from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_patterns from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import urllib import urllib
@ -77,7 +80,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
login_submission = _parse_json(request) login_submission = parse_json_object_from_request(request)
try: try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if login_submission["type"] == LoginRestServlet.PASS_TYPE:
if not self.password_enabled: if not self.password_enabled:
@ -250,7 +253,7 @@ class SAML2RestServlet(ClientV1RestServlet):
SP = Saml2Client(conf) SP = Saml2Client(conf)
saml2_auth = SP.parse_authn_request_response( saml2_auth = SP.parse_authn_request_response(
request.args['SAMLResponse'][0], BINDING_HTTP_POST) request.args['SAMLResponse'][0], BINDING_HTTP_POST)
except Exception, e: # Not authenticated except Exception as e: # Not authenticated
logger.exception(e) logger.exception(e)
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
username = saml2_auth.name_id.text username = saml2_auth.name_id.text
@ -263,7 +266,7 @@ class SAML2RestServlet(ClientV1RestServlet):
'?status=authenticated&access_token=' + '?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' + token + '&user_id=' + user_id + '&ava=' +
urllib.quote(json.dumps(saml2_auth.ava))) urllib.quote(json.dumps(saml2_auth.ava)))
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
defer.returnValue((200, {"status": "authenticated", defer.returnValue((200, {"status": "authenticated",
"user_id": user_id, "token": token, "user_id": user_id, "token": token,
@ -272,7 +275,7 @@ class SAML2RestServlet(ClientV1RestServlet):
request.redirect(urllib.unquote( request.redirect(urllib.unquote(
request.args['RelayState'][0]) + request.args['RelayState'][0]) +
'?status=not_authenticated') '?status=not_authenticated')
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
@ -309,7 +312,7 @@ class CasRedirectServlet(ClientV1RestServlet):
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
}) })
request.redirect("%s?%s" % (self.cas_server_url, service_param)) request.redirect("%s?%s" % (self.cas_server_url, service_param))
request.finish() finish_request(request)
class CasTicketServlet(ClientV1RestServlet): class CasTicketServlet(ClientV1RestServlet):
@ -362,7 +365,7 @@ class CasTicketServlet(ClientV1RestServlet):
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token) login_token)
request.redirect(redirect_url) request.redirect(redirect_url)
request.finish() finish_request(request)
def add_login_token_to_redirect_url(self, url, token): def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url)) url_parts = list(urlparse.urlparse(url))
@ -398,16 +401,6 @@ class CasTicketServlet(ClientV1RestServlet):
return (user, attributes) return (user, attributes)
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled: if hs.config.saml2_enabled:

View File

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes
from .base import ClientV1RestServlet, client_path_patterns
import logging
logger = logging.getLogger(__name__)
class LogoutRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/logout$")
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
try:
access_token = request.args["access_token"][0]
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
yield self.store.delete_access_token(access_token)
defer.returnValue((200, {}))
class LogoutAllRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/logout/all$")
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.auth = hs.get_auth()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
yield self.store.user_delete_access_tokens(user_id)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
LogoutRestServlet(hs).register(http_server)
LogoutAllRestServlet(hs).register(http_server)

View File

@ -17,11 +17,11 @@
""" """
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID from synapse.types import UserID
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( if requester.user != user:
target_user=user, auth_user=requester.user) allowed = yield self.handlers.presence_handler.is_visible(
observed_user=user, observer_user=requester.user,
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.handlers.presence_handler.get_state(target_user=user)
defer.returnValue((200, state)) defer.returnValue((200, state))
@ -45,10 +52,14 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} if requester.user != user:
try: raise AuthError(403, "Can only set your own presence state")
content = json.loads(request.content.read())
state = {}
content = parse_json_object_from_request(request)
try:
state["presence"] = content.pop("presence") state["presence"] = content.pop("presence")
if "status_msg" in content: if "status_msg" in content:
@ -63,8 +74,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except: except:
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state( yield self.handlers.presence_handler.set_state(user, state)
target_user=user, auth_user=requester.user, state=state)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -87,11 +97,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Cannot get another user's presence list") raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list( presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True) observer_user=user, accepted=True
)
for p in presence:
observed_user = p.pop("observed_user")
p["user_id"] = observed_user.to_string()
defer.returnValue((200, presence)) defer.returnValue((200, presence))
@ -107,11 +114,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError( raise SynapseError(
400, "Cannot modify another user's presence list") 400, "Cannot modify another user's presence list")
try: content = parse_json_object_from_request(request)
content = json.loads(request.content.read())
except:
logger.exception("JSON parse error")
raise SynapseError(400, "Unable to parse content")
if "invite" in content: if "invite" in content:
for u in content["invite"]: for u in content["invite"]:

View File

@ -18,8 +18,7 @@ from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from synapse.types import UserID from synapse.types import UserID
from synapse.http.servlet import parse_json_object_from_request
import simplejson as json
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@ -44,14 +43,15 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
content = parse_json_object_from_request(request)
try: try:
content = json.loads(request.content.read())
new_name = content["displayname"] new_name = content["displayname"]
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.handlers.profile_handler.set_displayname(
user, requester.user, new_name) user, requester, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -81,14 +81,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
content = parse_json_object_from_request(request)
try: try:
content = json.loads(request.content.read())
new_name = content["avatar_url"] new_name = content["avatar_url"]
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.handlers.profile_handler.set_avatar_url(
user, requester.user, new_name) user, requester, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -16,19 +16,16 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
) )
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from synapse.storage.push_rule import ( from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException InconsistentRuleException, RuleNotFoundException
) )
import synapse.push.baserules as baserules from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import ( from synapse.push.baserules import BASE_RULE_IDS
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
) from synapse.http.servlet import parse_json_value_from_request
import copy
import simplejson as json
class PushRuleRestServlet(ClientV1RestServlet): class PushRuleRestServlet(ClientV1RestServlet):
@ -36,6 +33,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request): def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
@ -49,18 +51,24 @@ class PushRuleRestServlet(ClientV1RestServlet):
if '/' in spec['rule_id'] or '\\' in spec['rule_id']: if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
content = _parse_json(request) content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
if 'attr' in spec: if 'attr' in spec:
yield self.set_rule_attr(requester.user.to_string(), spec, content) yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
if spec['rule_id'].startswith('.'):
# Rule ids starting with '.' are reserved for server default rules.
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
try: try:
(conditions, actions) = _rule_tuple_from_request_object( (conditions, actions) = _rule_tuple_from_request_object(
spec['template'], spec['template'],
spec['rule_id'], spec['rule_id'],
content, content,
device=spec['device'] if 'device' in spec else None
) )
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
@ -74,8 +82,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
after = _namespaced_rule_id(spec, after[0]) after = _namespaced_rule_id(spec, after[0])
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.store.add_push_rule(
user_id=requester.user.to_string(), user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
conditions=conditions, conditions=conditions,
@ -83,6 +91,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
before=before, before=before,
after=after after=after
) )
self.notify_user(user_id)
except InconsistentRuleException as e: except InconsistentRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
except RuleNotFoundException as e: except RuleNotFoundException as e:
@ -95,13 +104,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.hs.get_datastore().delete_push_rule( yield self.store.delete_push_rule(
requester.user.to_string(), namespaced_rule_id user_id, namespaced_rule_id
) )
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
@ -112,74 +123,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user( rawrules = yield self.store.get_push_rules_for_user(user_id)
user.to_string()
)
ruleslist = [] enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
r = _strip_device_condition(r)
if not profile_tag:
continue
if profile_tag not in rules['device']:
rules['device'][profile_tag] = {}
rules['device'][profile_tag] = (
_add_empty_priority_class_arrays(
rules['device'][profile_tag]
)
)
rulearray = rules['device'][profile_tag][template_name]
else:
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
path = request.postpath[1:] path = request.postpath[1:]
@ -195,30 +148,18 @@ class PushRuleRestServlet(ClientV1RestServlet):
path = path[1:] path = path[1:]
result = _filter_ruleset_with_path(rules['global'], path) result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result)) defer.returnValue((200, result))
elif path[0] == 'device':
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules['device']))
profile_tag = path[0]
path = path[1:]
if profile_tag not in rules['device']:
ret = {}
ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret))
ruleset = rules['device'][profile_tag]
result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result))
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event(
"push_rules_key", stream_id, users=[user_id]
)
def set_rule_attr(self, user_id, spec, val): def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled': if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val: if isinstance(val, dict) and "enabled" in val:
@ -229,16 +170,20 @@ class PushRuleRestServlet(ClientV1RestServlet):
# bools directly, so let's not break them. # bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean") raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return self.hs.get_datastore().set_push_rule_enabled( return self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val user_id, namespaced_rule_id, val
) )
else: elif spec['attr'] == 'actions':
raise UnrecognizedRequestError() actions = val.get('actions')
_check_actions(actions)
def get_rule_attr(self, user_id, namespaced_rule_id, attr): namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
if attr == 'enabled': rule_id = spec['rule_id']
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( is_default_rule = rule_id.startswith(".")
user_id, namespaced_rule_id if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
) )
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -252,16 +197,9 @@ def _rule_spec_from_path(path):
scope = path[1] scope = path[1]
path = path[2:] path = path[2:]
if scope not in ['global', 'device']: if scope != 'global':
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0: if len(path) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -278,8 +216,6 @@ def _rule_spec_from_path(path):
'template': template, 'template': template,
'rule_id': rule_id 'rule_id': rule_id
} }
if device:
spec['profile_tag'] = device
path = path[1:] path = path[1:]
@ -289,7 +225,7 @@ def _rule_spec_from_path(path):
return spec return spec
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
if rule_template in ['override', 'underride']: if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj: if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'") raise InvalidRuleException("Missing 'conditions'")
@ -322,16 +258,19 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
else: else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
if device:
conditions.append({
'kind': 'device',
'profile_tag': device
})
if 'actions' not in req_obj: if 'actions' not in req_obj:
raise InvalidRuleException("No actions found") raise InvalidRuleException("No actions found")
actions = req_obj['actions'] actions = req_obj['actions']
_check_actions(actions)
return conditions, actions
def _check_actions(actions):
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
for a in actions: for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']: if a in ['notify', 'dont_notify', 'coalesce']:
pass pass
@ -340,25 +279,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
else: else:
raise InvalidRuleException("Unrecognised action") raise InvalidRuleException("Unrecognised action")
return conditions, actions
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _profile_tag_from_conditions(conditions):
"""
Given a list of conditions, return the profile tag of the
device rule if there is one
"""
for c in conditions:
if c['kind'] == 'device':
return c['profile_tag']
return None
def _filter_ruleset_with_path(ruleset, path): def _filter_ruleset_with_path(ruleset, path):
if path == []: if path == []:
@ -393,93 +313,32 @@ def _filter_ruleset_with_path(ruleset, path):
attr = path[0] attr = path[0]
if attr in the_rule: if attr in the_rule:
return the_rule[attr] # Make sure we return a JSON object as the attribute may be a
# JSON value.
return {attr: the_rule[attr]}
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
def _priority_class_from_spec(spec): def _priority_class_from_spec(spec):
if spec['template'] not in PRIORITY_CLASS_MAP.keys(): if spec['template'] not in PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) raise InvalidRuleException("Unknown template: %s" % (spec['template']))
pc = PRIORITY_CLASS_MAP[spec['template']] pc = PRIORITY_CLASS_MAP[spec['template']]
if spec['scope'] == 'device':
pc += len(PRIORITY_CLASS_MAP)
return pc return pc
def _priority_class_to_template_name(pc):
if pc > PRIORITY_CLASS_MAP['override']:
# per-device
prio_class_index = pc - len(PRIORITY_CLASS_MAP)
return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else:
return PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _strip_device_condition(rule):
for i, c in enumerate(rule['conditions']):
if c['kind'] == 'device':
del rule['conditions'][i]
return rule
def _namespaced_rule_id_from_spec(spec): def _namespaced_rule_id_from_spec(spec):
return _namespaced_rule_id(spec, spec['rule_id']) return _namespaced_rule_id(spec, spec['rule_id'])
def _namespaced_rule_id(spec, rule_id): def _namespaced_rule_id(spec, rule_id):
if spec['scope'] == 'global': return "global/%s/%s" % (spec['template'], rule_id)
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
return "%s/%s/%s" % (scope, spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
class InvalidRuleException(Exception): class InvalidRuleException(Exception):
pass pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server) PushRuleRestServlet(hs).register(http_server)

View File

@ -17,9 +17,10 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,12 +29,16 @@ logger = logging.getLogger(__name__)
class PusherRestServlet(ClientV1RestServlet): class PusherRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/pushers/set$") PATTERNS = client_path_patterns("/pushers/set$")
def __init__(self, hs):
super(PusherRestServlet, self).__init__(hs)
self.notifier = hs.get_notifier()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
content = _parse_json(request) content = parse_json_object_from_request(request)
pusher_pool = self.hs.get_pusherpool() pusher_pool = self.hs.get_pusherpool()
@ -45,7 +50,7 @@ class PusherRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', reqd = ['kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data'] 'device_display_name', 'pushkey', 'lang', 'data']
missing = [] missing = []
for i in reqd: for i in reqd:
@ -73,36 +78,26 @@ class PusherRestServlet(ClientV1RestServlet):
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_id=user.to_string(), user_id=user.to_string(),
access_token=requester.access_token_id, access_token=requester.access_token_id,
profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],
app_display_name=content['app_display_name'], app_display_name=content['app_display_name'],
device_display_name=content['device_display_name'], device_display_name=content['device_display_name'],
pushkey=content['pushkey'], pushkey=content['pushkey'],
lang=content['lang'], lang=content['lang'],
data=content['data'] data=content['data'],
profile_tag=content.get('profile_tag', ""),
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError(400, "Config Error: " + pce.message, raise SynapseError(400, "Config Error: " + pce.message,
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
self.notifier.on_new_replication_data()
defer.returnValue((200, {})) defer.returnValue((200, {}))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
PusherRestServlet(hs).register(http_server) PusherRestServlet(hs).register(http_server)

View File

@ -18,14 +18,14 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from hashlib import sha1 from hashlib import sha1
import hmac import hmac
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -98,7 +98,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
register_json = _parse_json(request) register_json = parse_json_object_from_request(request)
session = (register_json["session"] session = (register_json["session"]
if "session" in register_json else None) if "session" in register_json else None)
@ -355,15 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet):
) )
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@ -16,14 +16,14 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ """ This module contains REST servlets to do with rooms: /rooms/<paths> """
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request
import simplejson as json
import logging import logging
import urllib import urllib
@ -63,35 +63,18 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(
room_config,
requester.user,
None,
)
room_config.update(info)
defer.returnValue((200, info))
@defer.inlineCallbacks
def make_room(self, room_config, auth_user, room_id):
handler = self.handlers.room_creation_handler handler = self.handlers.room_creation_handler
info = yield handler.create_room( info = yield handler.create_room(
user_id=auth_user.to_string(), requester, self.get_room_config(request)
room_id=room_id,
config=room_config
) )
defer.returnValue(info)
defer.returnValue((200, info))
def get_room_config(self, request): def get_room_config(self, request):
try: user_supplied_config = parse_json_object_from_request(request)
user_supplied_config = json.loads(request.content.read()) # default visibility
if "visibility" not in user_supplied_config: user_supplied_config.setdefault("visibility", "public")
# default visibility return user_supplied_config
user_supplied_config["visibility"] = "public"
return user_supplied_config
except (ValueError, TypeError):
raise SynapseError(400, "Body must be JSON.",
errcode=Codes.BAD_JSON)
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -149,7 +132,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
event_dict = { event_dict = {
"type": event_type, "type": event_type,
@ -162,11 +145,22 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event( event, context = yield msg_handler.create_event(
event_dict, token_id=requester.access_token_id, txn_id=txn_id, event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
) )
defer.returnValue((200, {})) if event_type == EventTypes.Member:
yield self.handlers.room_member_handler.send_membership_event(
requester,
event,
context,
)
else:
yield msg_handler.send_nonmember_event(requester, event, context)
defer.returnValue((200, {"event_id": event.event_id}))
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
@ -180,17 +174,17 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event( event = yield msg_handler.create_and_send_nonmember_event(
requester,
{ {
"type": event_type, "type": event_type,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -229,46 +223,37 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
allow_guest=True, allow_guest=True,
) )
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
# SynapseErrors.
identifier = None
is_room_alias = False
try: try:
identifier = RoomAlias.from_string(room_identifier) content = parse_json_object_from_request(request)
is_room_alias = True except:
except SynapseError: # Turns out we used to ignore the body entirely, and some clients
identifier = RoomID.from_string(room_identifier) # cheekily send invalid bodies.
content = {}
# TODO: Support for specifying the home server to join with? if RoomID.is_valid(room_identifier):
room_id = room_identifier
if is_room_alias: remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias( room_alias = RoomAlias.from_string(room_identifier)
requester.user, room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
identifier, room_id = room_id.to_string()
) else:
defer.returnValue((200, ret_dict)) raise SynapseError(400, "%s was not legal room ID or room alias" % (
else: # room id room_identifier,
msg_handler = self.handlers.message_handler ))
content = {"membership": Membership.JOIN}
if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": identifier.to_string(),
"sender": requester.user.to_string(),
"state_key": requester.user.to_string(),
},
token_id=requester.access_token_id,
txn_id=txn_id,
is_guest=requester.is_guest,
)
defer.returnValue((200, {"room_id": identifier.to_string()})) yield self.handlers.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
action="join",
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=content.get("third_party_signed", None),
)
defer.returnValue((200, {"room_id": room_id}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
@ -316,18 +301,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
if event["type"] != EventTypes.Member: if event["type"] != EventTypes.Member:
continue continue
chunk.append(event) chunk.append(event)
# FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user,
auth_user=requester.user,
)
event["content"].update(presence_state)
except:
pass
defer.returnValue((200, { defer.returnValue((200, {
"chunk": chunk "chunk": chunk
@ -454,7 +427,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
}: }:
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
content = _parse_json(request) try:
content = parse_json_object_from_request(request)
except:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.handlers.room_member_handler.do_3pid_invite(
@ -463,7 +441,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content["medium"], content["medium"],
content["address"], content["address"],
content["id_server"], content["id_server"],
requester.access_token_id, requester,
txn_id txn_id
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -481,6 +459,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
action=membership_action, action=membership_action,
txn_id=txn_id, txn_id=txn_id,
third_party_signed=content.get("third_party_signed", None),
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -516,10 +495,11 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event( event = yield msg_handler.create_and_send_nonmember_event(
requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
"content": content, "content": content,
@ -527,7 +507,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
"redacts": event_id, "redacts": event_id,
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -553,6 +532,10 @@ class RoomTypingRestServlet(ClientV1RestServlet):
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
) )
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -560,10 +543,12 @@ class RoomTypingRestServlet(ClientV1RestServlet):
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))
content = _parse_json(request) content = parse_json_object_from_request(request)
typing_handler = self.handlers.typing_notification_handler typing_handler = self.handlers.typing_notification_handler
yield self.presence_handler.bump_presence_active_time(requester.user)
if content["typing"]: if content["typing"]:
yield typing_handler.started_typing( yield typing_handler.started_typing(
target_user=target_user, target_user=target_user,
@ -590,7 +575,7 @@ class SearchRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0] batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search( results = yield self.handlers.search_handler.search(
@ -602,17 +587,6 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_txn_path(servlet, regex_string, http_server, with_get=False): def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""Registers a transaction-based path. """Registers a transaction-based path.

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import hmac import hmac

View File

@ -17,11 +17,9 @@
""" """
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.errors import SynapseError
import re import re
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,23 +42,3 @@ def client_v2_patterns(path_regex, releases=(0,)):
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex)) patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns return patterns
def parse_request_allow_empty(request):
content = request.content.read()
if content is None or content == '':
return None
try:
return simplejson.loads(content)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")
def parse_json_dict_from_request(request):
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")

View File

@ -17,10 +17,10 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
import logging import logging
@ -41,9 +41,9 @@ class PasswordRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
authed, result, params = yield self.auth_handler.check_auth([ authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet):
new_password = params['new_password'] new_password = params['new_password']
yield self.auth_handler.set_password( yield self.auth_handler.set_password(
user_id, new_password user_id, new_password, requester
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -114,7 +114,7 @@ class ThreepidRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
threePidCreds = body.get('threePidCreds') threePidCreds = body.get('threePidCreds')
threePidCreds = body.get('three_pid_creds', threePidCreds) threePidCreds = body.get('three_pid_creds', threePidCreds)

View File

@ -15,15 +15,13 @@
from ._base import client_v2_patterns from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError
from twisted.internet import defer from twisted.internet import defer
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,11 +45,7 @@ class AccountDataServlet(RestServlet):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
try: body = parse_json_object_from_request(request)
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
max_id = yield self.store.add_account_data_for_user( max_id = yield self.store.add_account_data_for_user(
user_id, account_data_type, body user_id, account_data_type, body
@ -86,14 +80,7 @@ class RoomAccountDataServlet(RestServlet):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
try: body = parse_json_object_from_request(request)
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
if not isinstance(body, dict):
raise ValueError("Expected a JSON object")
max_id = yield self.store.add_account_data_to_room( max_id = yield self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -130,7 +131,7 @@ class AuthRestServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes) request.write(html_bytes)
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
else: else:
raise SynapseError(404, "Unknown auth stage type") raise SynapseError(404, "Unknown auth stage type")
@ -176,7 +177,7 @@ class AuthRestServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes) request.write(html_bytes)
request.finish() finish_request(request)
defer.returnValue(None) defer.returnValue(None)
else: else:

View File

@ -16,12 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_patterns from ._base import client_v2_patterns
import simplejson as json
import logging import logging
@ -84,12 +83,7 @@ class CreateFilterRestServlet(RestServlet):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users") raise SynapseError(400, "Can only create filters for local users")
try: content = parse_json_object_from_request(request)
content = json.loads(request.content.read())
# TODO(paul): check for required keys and invalid keys
except:
raise SynapseError(400, "Invalid filter definition")
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,

View File

@ -15,16 +15,15 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.servlet import RestServlet
from synapse.types import UserID from synapse.types import UserID
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from ._base import client_v2_patterns from ._base import client_v2_patterns
import simplejson as json
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,10 +67,9 @@ class KeyUploadServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication # TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead. # or derive the device_id from the authentication instead.
try:
body = json.loads(request.content.read()) body = parse_json_object_from_request(request)
except:
raise SynapseError(400, "Invalid key JSON")
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
@ -173,10 +171,7 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id): def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
try: body = parse_json_object_from_request(request)
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
result = yield self.handle_request(body) result = yield self.handle_request(body)
defer.returnValue(result) defer.returnValue(result)
@ -272,10 +267,7 @@ class OneTimeKeyServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm): def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
try: body = parse_json_object_from_request(request)
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
result = yield self.handle_request(body) result = yield self.handle_request(body)
defer.returnValue(result) defer.returnValue(result)

View File

@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler self.receipts_handler = hs.get_handlers().receipts_handler
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet):
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.presence_handler.bump_presence_active_time(requester.user)
yield self.receipts_handler.received_client_receipt( yield self.receipts_handler.received_client_receipt(
room_id, room_id,
receipt_type, receipt_type,

View File

@ -17,9 +17,9 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
import logging import logging
import hmac import hmac
@ -73,7 +73,7 @@ class RegisterRestServlet(RestServlet):
ret = yield self.onEmailTokenRequest(request) ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret) defer.returnValue(ret)
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these # we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us. # in sessions. Pull out the username/password provided to us.
@ -122,10 +122,22 @@ class RegisterRestServlet(RestServlet):
guest_access_token = body.get("guest_access_token", None) guest_access_token = body.get("guest_access_token", None)
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if desired_username is not None: if desired_username is not None:
yield self.registration_handler.check_username( yield self.registration_handler.check_username(
desired_username, desired_username,
guest_access_token=guest_access_token guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
) )
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -139,7 +151,7 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
] ]
authed, result, params = yield self.auth_handler.check_auth( authed, result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
) )
@ -147,6 +159,22 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((401, result)) defer.returnValue((401, result))
return return
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
registered_user_id
)
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
refresh_token = yield self.auth_handler.issue_refresh_token(
registered_user_id
)
defer.returnValue((200, {
"user_id": registered_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
}))
# NB: This may be from the auth handler and NOT from the POST # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: if 'password' not in params:
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
@ -161,6 +189,12 @@ class RegisterRestServlet(RestServlet):
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
) )
# remember that we've now registered that user account, and with what
# user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", user_id
)
if result and LoginType.EMAIL_IDENTITY in result: if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY] threepid = result[LoginType.EMAIL_IDENTITY]
@ -187,7 +221,7 @@ class RegisterRestServlet(RestServlet):
else: else:
logger.info("bind_email not specified: not binding email") logger.info("bind_email not specified: not binding email")
result = self._create_registration_details(user_id, token) result = yield self._create_registration_details(user_id, token)
defer.returnValue((200, result)) defer.returnValue((200, result))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
@ -198,7 +232,7 @@ class RegisterRestServlet(RestServlet):
(user_id, token) = yield self.registration_handler.appservice_register( (user_id, token) = yield self.registration_handler.appservice_register(
username, as_token username, as_token
) )
defer.returnValue(self._create_registration_details(user_id, token)) defer.returnValue((yield self._create_registration_details(user_id, token)))
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac): def _do_shared_secret_registration(self, username, password, mac):
@ -225,18 +259,21 @@ class RegisterRestServlet(RestServlet):
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=username, password=password localpart=username, password=password
) )
defer.returnValue(self._create_registration_details(user_id, token)) defer.returnValue((yield self._create_registration_details(user_id, token)))
@defer.inlineCallbacks
def _create_registration_details(self, user_id, token): def _create_registration_details(self, user_id, token):
return { refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
defer.returnValue({
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} "refresh_token": refresh_token,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def onEmailTokenRequest(self, request): def onEmailTokenRequest(self, request):
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt'] required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = [] absent = []

View File

@ -25,6 +25,7 @@ from synapse.events.utils import (
) )
from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from ._base import client_v2_patterns from ._base import client_v2_patterns
import copy import copy
@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet):
self.sync_handler = hs.get_handlers().sync_handler self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet):
else: else:
since_token = None since_token = None
if set_presence == "online": affect_presence = set_presence != PresenceState.OFFLINE
yield self.event_stream_handler.started_stream(user)
try: if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence})
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,
)
with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout, sync_config, since_token=since_token, timeout=timeout,
full_state=full_state full_state=full_state
) )
finally:
if set_presence == "online":
self.event_stream_handler.stopped_stream(user)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -15,15 +15,13 @@
from ._base import client_v2_patterns from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError
from twisted.internet import defer from twisted.internet import defer
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,11 +70,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
try: body = parse_json_object_from_request(request)
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid tag JSON")
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)

View File

@ -16,9 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -35,7 +35,7 @@ class TokenRefreshRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
try: try:
old_refresh_token = body["refresh_token"] old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler auth_handler = self.hs.get_handlers().auth_handler

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from synapse.http.server import request_handler, respond_with_json_bytes from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -22,7 +22,6 @@ from twisted.internet import defer
from io import BytesIO from io import BytesIO
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -126,14 +125,7 @@ class RemoteKey(Resource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def async_render_POST(self, request): def async_render_POST(self, request):
try: content = parse_json_object_from_request(request)
content = json.loads(request.content.read())
if type(content) != dict:
raise ValueError()
except ValueError:
raise SynapseError(
400, "Content must be JSON object.", errcode=Codes.NOT_JSON
)
query = content["server_keys"] query = content["server_keys"]

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.server import respond_with_json_bytes from synapse.http.server import respond_with_json_bytes, finish_request
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import ( from synapse.api.errors import (
@ -144,7 +144,7 @@ class ContentRepoResource(resource.Resource):
# after the file has been sent, clean up and finish the request # after the file has been sent, clean up and finish the request
def cbFinished(ignored): def cbFinished(ignored):
f.close() f.close()
request.finish() finish_request(request)
d.addCallback(cbFinished) d.addCallback(cbFinished)
else: else:
respond_with_json_bytes( respond_with_json_bytes(

View File

@ -16,7 +16,7 @@
from .thumbnailer import Thumbnailer from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json, finish_request
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import ( from synapse.api.errors import (
cs_error, Codes, SynapseError cs_error, Codes, SynapseError
@ -238,7 +238,7 @@ class BaseMediaResource(Resource):
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
yield FileSender().beginFileTransfer(f, request) yield FileSender().beginFileTransfer(f, request)
request.finish() finish_request(request)
else: else:
self._respond_404(request) self._respond_404(request)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes from synapse.api.auth import AuthEventTypes
@ -27,6 +28,7 @@ from collections import namedtuple
import logging import logging
import hashlib import hashlib
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,8 +36,11 @@ logger = logging.getLogger(__name__)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
SIZE_OF_CACHE = 1000 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
EVICTION_TIMEOUT_SECONDS = 20
SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
class _StateCacheEntry(object): class _StateCacheEntry(object):
@ -85,16 +90,8 @@ class StateHandler(object):
""" """
event_ids = yield self.store.get_latest_event_ids_in_room(room_id) event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
cache = None res = yield self.resolve_state_groups(room_id, event_ids)
if self._state_cache is not None: state = res[1]
cache = self._state_cache.get(frozenset(event_ids), None)
if cache:
cache.ts = self.clock.time_msec()
state = cache.state
else:
res = yield self.resolve_state_groups(room_id, event_ids)
state = res[1]
if event_type: if event_type:
defer.returnValue(state.get((event_type, state_key))) defer.returnValue(state.get((event_type, state_key)))
@ -186,20 +183,6 @@ class StateHandler(object):
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
if self._state_cache is not None:
cache = self._state_cache.get(frozenset(event_ids), None)
if cache and cache.state_group:
cache.ts = self.clock.time_msec()
prev_state = cache.state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue(
(cache.state_group, cache.state, prev_states)
)
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
room_id, event_ids room_id, event_ids
) )
@ -209,7 +192,7 @@ class StateHandler(object):
state_groups.keys() state_groups.keys()
) )
group_names = set(state_groups.keys()) group_names = frozenset(state_groups.keys())
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups.items().pop() name, state_list = state_groups.items().pop()
state = { state = {
@ -223,29 +206,38 @@ class StateHandler(object):
else: else:
prev_states = [] prev_states = []
if self._state_cache is not None:
cache = _StateCacheEntry(
state=state,
state_group=name,
ts=self.clock.time_msec()
)
self._state_cache[frozenset(event_ids)] = cache
defer.returnValue((name, state, prev_states)) defer.returnValue((name, state, prev_states))
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache and cache.state_group:
cache.ts = self.clock.time_msec()
event_dict = yield self.store.get_events(cache.state.values())
state = {(e.type, e.state_key): e for e in event_dict.values()}
prev_state = state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue(
(cache.state_group, state, prev_states)
)
new_state, prev_states = self._resolve_events( new_state, prev_states = self._resolve_events(
state_groups.values(), event_type, state_key state_groups.values(), event_type, state_key
) )
if self._state_cache is not None: if self._state_cache is not None:
cache = _StateCacheEntry( cache = _StateCacheEntry(
state=new_state, state={key: event.event_id for key, event in new_state.items()},
state_group=None, state_group=None,
ts=self.clock.time_msec() ts=self.clock.time_msec()
) )
self._state_cache[frozenset(event_ids)] = cache self._state_cache[group_names] = cache
defer.returnValue((None, new_state, prev_states)) defer.returnValue((None, new_state, prev_states))
@ -263,48 +255,49 @@ class StateHandler(object):
from (type, state_key) to event. prev_states is a list of event_ids. from (type, state_key) to event. prev_states is a list of event_ids.
:rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str]) :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
""" """
state = {} with Measure(self.clock, "state._resolve_events"):
for st in state_sets: state = {}
for e in st: for st in state_sets:
state.setdefault( for e in st:
(e.type, e.state_key), state.setdefault(
{} (e.type, e.state_key),
)[e.event_id] = e {}
)[e.event_id] = e
unconflicted_state = { unconflicted_state = {
k: v.values()[0] for k, v in state.items() k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1 if len(v.values()) == 1
} }
conflicted_state = { conflicted_state = {
k: v.values() k: v.values()
for k, v in state.items() for k, v in state.items()
if len(v.values()) > 1 if len(v.values()) > 1
} }
if event_type: if event_type:
prev_states_events = conflicted_state.get( prev_states_events = conflicted_state.get(
(event_type, state_key), [] (event_type, state_key), []
) )
prev_states = [s.event_id for s in prev_states_events] prev_states = [s.event_id for s in prev_states_events]
else: else:
prev_states = [] prev_states = []
auth_events = { auth_events = {
k: e for k, e in unconflicted_state.items() k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes if k[0] in AuthEventTypes
} }
try: try:
resolved_state = self._resolve_state_events( resolved_state = self._resolve_state_events(
conflicted_state, auth_events conflicted_state, auth_events
) )
except: except:
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
new_state = unconflicted_state new_state = unconflicted_state
new_state.update(resolved_state) new_state.update(resolved_state)
return new_state, prev_states return new_state, prev_states

View File

@ -20,7 +20,7 @@ from .appservice import (
from ._base import Cache from ._base import Cache
from .directory import DirectoryStore from .directory import DirectoryStore
from .events import EventsStore from .events import EventsStore
from .presence import PresenceStore from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore from .profile import ProfileStore
from .registration import RegistrationStore from .registration import RegistrationStore
from .room import RoomStore from .room import RoomStore
@ -45,8 +45,9 @@ from .search import SearchStore
from .tags import TagsStore from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -110,16 +111,25 @@ class DataStore(RoomMemberStore, RoomStore,
self._account_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id" db_conn, "account_data_max_stream_id", "stream_id"
) )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
)
events_max = self._stream_id_gen.get_max_token(None) events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events", db_conn, "events",
entity_column="room_id", entity_column="room_id",
@ -135,13 +145,43 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max, "MembershipStreamChangeCache", events_max,
) )
account_max = self._account_data_id_gen.get_max_token(None) account_max = self._account_data_id_gen.get_max_token()
self._account_data_stream_cache = StreamChangeCache( self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max, "AccountDataAndTagsChangeCache", account_max,
) )
self.__presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_max_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
prefilled_cache=presence_cache_prefill
)
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_max_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
active_on_startup = self.__presence_on_startup
self.__presence_on_startup = None
return active_on_startup
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will # It doesn't really matter how many we get, the StreamChangeCache will
@ -161,6 +201,7 @@ class DataStore(RoomMemberStore, RoomStore,
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
rows = txn.fetchall() rows = txn.fetchall()
txn.close()
cache = { cache = {
row[0]: int(row[1]) row[0]: int(row[1])
@ -174,6 +215,28 @@ class DataStore(RoomMemberStore, RoomStore,
return cache, min_val return cache, min_val
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
sql = (
"SELECT user_id, state, last_active_ts, last_federation_update_ts,"
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.cursor_to_dict(txn)
txn.close()
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return [UserPresenceState(**row) for row in rows]
@defer.inlineCallbacks @defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent): def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())

View File

@ -18,6 +18,7 @@ from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict
import synapse.metrics import synapse.metrics
@ -26,6 +27,10 @@ from twisted.internet import defer
import sys import sys
import time import time
import threading import threading
import os
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -163,7 +168,9 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000) self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
)
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
@ -344,7 +351,7 @@ class SQLBaseStore(object):
""" """
col_headers = list(column[0] for column in cursor.description) col_headers = list(column[0] for column in cursor.description)
results = list( results = list(
dict(zip(col_headers, row)) for row in cursor.fetchall() intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
) )
return results return results
@ -766,6 +773,19 @@ class SQLBaseStore(object):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(
desc, self._simple_delete_one_txn, table, keyvalues
)
@staticmethod
def _simple_delete_one_txn(txn, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args: Args:
table : string giving the table name table : string giving the table name
keyvalues : dict of column names and values to select the row with keyvalues : dict of column names and values to select the row with
@ -775,13 +795,11 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
def func(txn): txn.execute(sql, keyvalues.values())
txn.execute(sql, keyvalues.values()) if txn.rowcount == 0:
if txn.rowcount == 0: raise StoreError(404, "No row found")
raise StoreError(404, "No row found") if txn.rowcount > 1:
if txn.rowcount > 1: raise StoreError(500, "more than one row matched")
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
@staticmethod @staticmethod
def _simple_delete_txn(txn, table, keyvalues): def _simple_delete_txn(txn, table, keyvalues):

View File

@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn "get_account_data_for_room", get_account_data_for_room_txn
) )
def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): def get_all_updated_account_data(self, last_global_id, last_room_id,
"""Get all the client account_data for a that's changed. current_id, limit):
"""Get all the client account_data that has changed on the server
Args:
last_global_id(int): The position to fetch from for top level data
last_room_id(int): The position to fetch from for per room data
current_id(int): The position to fetch up to.
Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
room_id string, type string, and content string.
"""
def get_updated_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type, content"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_global_id, current_id, limit))
global_results = txn.fetchall()
sql = (
"SELECT stream_id, user_id, room_id, account_data_type, content"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
return (global_results, room_results)
return self.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
"""Get all the client account_data for a that's changed for a user
Args: Args:
user_id(str): The user to get the account_data for. user_id(str): The user to get the account_data for.
@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore):
) )
self._update_max_stream_id(txn, next_id) self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id "add_room_account_data", add_account_data_txn, next_id
) )
result = yield self._account_data_id_gen.get_max_token(self) result = self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore):
) )
self._update_max_stream_id(txn, next_id) self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id "add_user_account_data", add_account_data_txn, next_id
) )
result = yield self._account_data_id_gen.get_max_token(self) result = self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id): def _update_max_stream_id(self, txn, next_id):

View File

@ -34,8 +34,8 @@ class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs) super(ApplicationServiceStore, self).__init__(hs)
self.hostname = hs.hostname self.hostname = hs.hostname
self.services_cache = [] self.services_cache = ApplicationServiceStore.load_appservices(
self._populate_appservice_cache( hs.hostname,
hs.config.app_service_config_files hs.config.app_service_config_files
) )
@ -144,21 +144,23 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id return rooms_for_user_matching_user_id
def _load_appservice(self, as_info): @classmethod
def _load_appservice(cls, hostname, as_info, config_filename):
required_string_fields = [ required_string_fields = [
# TODO: Add id here when it's stable to release "id", "url", "as_token", "hs_token", "sender_localpart"
"url", "as_token", "hs_token", "sender_localpart"
] ]
for field in required_string_fields: for field in required_string_fields:
if not isinstance(as_info.get(field), basestring): if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s'", field) raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
localpart = as_info["sender_localpart"] localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart: if urllib.quote(localpart) != localpart:
raise ValueError( raise ValueError(
"sender_localpart needs characters which are not URL encoded." "sender_localpart needs characters which are not URL encoded."
) )
user = UserID(localpart, self.hostname) user = UserID(localpart, hostname)
user_id = user.to_string() user_id = user.to_string()
# namespace checks # namespace checks
@ -188,25 +190,30 @@ class ApplicationServiceStore(SQLBaseStore):
namespaces=as_info["namespaces"], namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["id"] if "id" in as_info else as_info["as_token"], id=as_info["id"],
) )
def _populate_appservice_cache(self, config_files): @classmethod
"""Populates a cache of Application Services from the config files.""" def load_appservices(cls, hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list): if not isinstance(config_files, list):
logger.warning( logger.warning(
"Expected %s to be a list of AS config files.", config_files "Expected %s to be a list of AS config files.", config_files
) )
return return []
# Dicts of value -> filename # Dicts of value -> filename
seen_as_tokens = {} seen_as_tokens = {}
seen_ids = {} seen_ids = {}
appservices = []
for config_file in config_files: for config_file in config_files:
try: try:
with open(config_file, 'r') as f: with open(config_file, 'r') as f:
appservice = self._load_appservice(yaml.load(f)) appservice = ApplicationServiceStore._load_appservice(
hostname, yaml.load(f), config_file
)
if appservice.id in seen_ids: if appservice.id in seen_ids:
raise ConfigError( raise ConfigError(
"Cannot reuse ID across application services: " "Cannot reuse ID across application services: "
@ -226,11 +233,12 @@ class ApplicationServiceStore(SQLBaseStore):
) )
seen_as_tokens[appservice.token] = config_file seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice) logger.info("Loaded application service: %s", appservice)
self.services_cache.append(appservice) appservices.append(appservice)
except Exception as e: except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file) logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e) logger.exception(e)
raise raise
return appservices
class ApplicationServiceTransactionStore(SQLBaseStore): class ApplicationServiceTransactionStore(SQLBaseStore):

View File

@ -70,13 +70,14 @@ class DirectoryStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers): def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers """ Creates an associatin between a room alias and room_id/servers
Args: Args:
room_alias (RoomAlias) room_alias (RoomAlias)
room_id (str) room_id (str)
servers (list) servers (list)
creator (str): Optional user_id of creator.
Returns: Returns:
Deferred Deferred
@ -87,6 +88,7 @@ class DirectoryStore(SQLBaseStore):
{ {
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
"room_id": room_id, "room_id": room_id,
"creator": creator,
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
@ -107,6 +109,17 @@ class DirectoryStore(SQLBaseStore):
) )
self.get_aliases_for_room.invalidate((room_id,)) self.get_aliases_for_room.invalidate((room_id,))
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
table="room_aliases",
keyvalues={
"room_alias": room_alias,
},
retcol="creator",
desc="get_room_alias_creator",
allow_none=True
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction( room_id = yield self.runInteraction(
@ -142,7 +155,7 @@ class DirectoryStore(SQLBaseStore):
return room_id return room_id
@cached() @cached(max_entries=5000)
def get_aliases_for_room(self, room_id): def get_aliases_for_room(self, room_id):
return self._simple_select_onecol( return self._simple_select_onecol(
"room_aliases", "room_aliases",

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore from ._base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore): class EndToEndKeyStore(SQLBaseStore):

View File

@ -26,12 +26,13 @@ SUPPORTED_MODULE = {
} }
def create_engine(name): def create_engine(config):
name = config.database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None) engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class: if engine_class:
module = importlib.import_module(name) module = importlib.import_module(name)
return engine_class(module) return engine_class(module, config=config)
raise RuntimeError( raise RuntimeError(
"Unsupported database engine '%s'" % (name,) "Unsupported database engine '%s'" % (name,)

View File

@ -21,9 +21,10 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object): class PostgresEngine(object):
single_threaded = False single_threaded = False
def __init__(self, database_module): def __init__(self, database_module, config):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)
self.config = config
def check_database(self, txn): def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING") txn.execute("SHOW SERVER_ENCODING")
@ -44,7 +45,7 @@ class PostgresEngine(object):
) )
def prepare_database(self, db_conn): def prepare_database(self, db_conn):
prepare_database(db_conn, self) prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):

View File

@ -23,8 +23,9 @@ import struct
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True single_threaded = True
def __init__(self, database_module): def __init__(self, database_module, config):
self.module = database_module self.module = database_module
self.config = config
def check_database(self, txn): def check_database(self, txn):
pass pass
@ -38,7 +39,7 @@ class Sqlite3Engine(object):
def prepare_database(self, db_conn): def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn) prepare_sqlite3_database(db_conn)
prepare_database(db_conn, self) prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error): def is_deadlock(self, error):
return False return False

View File

@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore):
retcol="event_id", retcol="event_id",
) )
def get_latest_events_in_room(self, room_id): def get_latest_event_ids_and_hashes_in_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_latest_events_in_room", "get_latest_event_ids_and_hashes_in_room",
self._get_latest_events_in_room, self._get_latest_event_ids_and_hashes_in_room,
room_id, room_id,
) )
@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore):
desc="get_latest_event_ids_in_room", desc="get_latest_event_ids_in_room",
) )
def _get_latest_events_in_room(self, txn, room_id): def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
sql = ( sql = (
"SELECT e.event_id, e.depth FROM events as e " "SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f " "INNER JOIN event_forward_extremities as f "

View File

@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore):
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
""" """
:param event: the event set actions for :param event: the event set actions for
:param tuples: list of tuples of (user_id, profile_tag, actions) :param tuples: list of tuples of (user_id, actions)
""" """
values = [] values = []
for uid, profile_tag, actions in tuples: for uid, actions in tuples:
values.append({ values.append({
'room_id': event.room_id, 'room_id': event.room_id,
'event_id': event.event_id, 'event_id': event.event_id,
'user_id': uid, 'user_id': uid,
'profile_tag': profile_tag,
'actions': json.dumps(actions), 'actions': json.dumps(actions),
'stream_ordering': event.internal_metadata.stream_ordering, 'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth, 'topological_ordering': event.depth,
@ -43,14 +42,14 @@ class EventPushActionsStore(SQLBaseStore):
'highlight': 1 if _action_has_highlight(actions) else 0, 'highlight': 1 if _action_has_highlight(actions) else 0,
}) })
for uid, _, __ in tuples: for uid, __ in tuples:
txn.call_after( txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid) (event.room_id, uid)
) )
self._simple_insert_many_txn(txn, "event_push_actions", values) self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True) @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
): ):

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore, _RollbackButIsFineException from ._base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
yield stream_orderings yield stream_orderings
stream_ordering_manager = stream_ordering_manager() stream_ordering_manager = stream_ordering_manager()
else: else:
stream_ordering_manager = yield self._stream_id_gen.get_next_mult( stream_ordering_manager = self._stream_id_gen.get_next_mult(
self, len(events_and_contexts) len(events_and_contexts)
) )
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
@ -101,37 +101,23 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, backfilled=False, def persist_event(self, event, context,
is_new_state=True, current_state=None): is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
self.min_stream_token -= 1
stream_ordering = self.min_stream_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
else:
@contextmanager
def stream_ordering_manager():
yield stream_ordering
stream_ordering_manager = stream_ordering_manager()
try: try:
with stream_ordering_manager as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context, context=context,
backfilled=backfilled,
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
max_persisted_id = yield self._stream_id_gen.get_max_token(self) max_persisted_id = yield self._stream_id_gen.get_max_token()
defer.returnValue((stream_ordering, max_persisted_id)) defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -165,13 +151,38 @@ class EventsStore(SQLBaseStore):
defer.returnValue(events[0] if events else None) defer.returnValue(events[0] if events else None)
@defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self._get_events(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
defer.returnValue({e.event_id: e for e in events})
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context,
is_new_state=True, current_state=None): is_new_state=True, current_state=None):
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
@ -198,7 +209,7 @@ class EventsStore(SQLBaseStore):
return self._persist_events_txn( return self._persist_events_txn(
txn, txn,
[(event, context)], [(event, context)],
backfilled=backfilled, backfilled=False,
is_new_state=is_new_state, is_new_state=is_new_state,
) )
@ -455,7 +466,7 @@ class EventsStore(SQLBaseStore):
for event, _ in state_events_and_contexts: for event, _ in state_events_and_contexts:
if not context.rejected: if not context.rejected:
txn.call_after( txn.call_after(
self.get_current_state_for_key.invalidate, self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,) (event.room_id, event.type, event.state_key,)
) )
@ -526,6 +537,9 @@ class EventsStore(SQLBaseStore):
if not event_ids: if not event_ids:
defer.returnValue([]) defer.returnValue([])
event_id_list = event_ids
event_ids = set(event_ids)
event_map = self._get_events_from_cache( event_map = self._get_events_from_cache(
event_ids, event_ids,
check_redacted=check_redacted, check_redacted=check_redacted,
@ -535,23 +549,18 @@ class EventsStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_map] missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids: if missing_events_ids:
defer.returnValue([ missing_events = yield self._enqueue_events(
event_map[e_id] for e_id in event_ids missing_events_ids,
if e_id in event_map and event_map[e_id] check_redacted=check_redacted,
]) get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
missing_events = yield self._enqueue_events( event_map.update(missing_events)
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
defer.returnValue([ defer.returnValue([
event_map[e_id] for e_id in event_ids event_map[e_id] for e_id in event_id_list
if e_id in event_map and event_map[e_id] if e_id in event_map and event_map[e_id]
]) ])
@ -1064,3 +1073,48 @@ class EventsStore(SQLBaseStore):
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
defer.returnValue(result) defer.returnValue(result)
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
# TODO: Fix race with the persit_event txn by using one of the
# stream id managers
return -self.min_stream_token
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
"""Get all the new events that have arrived at the server either as
new events or as backfilled events"""
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, ej.internal_metadata, ej.json"
" FROM events as e"
" JOIN event_json as ej"
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
" WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
if last_forward_id != current_forward_id:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
else:
new_forward_events = []
sql = (
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
" FROM events as e"
" JOIN event_json as ej"
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
" WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
" ORDER BY e.stream_ordering DESC"
" LIMIT ?"
)
if last_backfill_id != current_backfill_id:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
else:
new_backfill_events = []
return (new_forward_events, new_backfill_events)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer

Some files were not shown because too many files have changed in this diff Show More