Merge branch 'master' of git+ssh://github.com/matrix-org/synapse

This commit is contained in:
Matthew Hodgson 2016-02-10 16:27:15 +00:00
commit 7634687057
339 changed files with 5769 additions and 4015 deletions

View File

@ -51,3 +51,6 @@ Steven Hammerton <steven.hammerton at openmarket.com>
Mads Robin Christensen <mads at v42 dot dk> Mads Robin Christensen <mads at v42 dot dk>
* CentOS 7 installation instructions. * CentOS 7 installation instructions.
Florent Violleau <floviolleau at gmail dot com>
* Add Raspberry Pi installation instructions and general troubleshooting items

View File

@ -1,3 +1,72 @@
Changes in synapse v0.13.0 (2016-02-10)
=======================================
This version includes an upgrade of the schema, specifically adding an index to
the ``events`` table. This may cause synapse to pause for several minutes the
first time it is started after the upgrade.
Changes:
* Improve general performance (PR #540, #543. #544, #54, #549, #567)
* Change guest user ids to be incrementing integers (PR #550)
* Improve performance of public room list API (PR #552)
* Change profile API to omit keys rather than return null (PR #557)
* Add ``/media/r0`` endpoint prefix, which is equivalent to ``/media/v1/``
(PR #595)
Bug fixes:
* Fix bug with upgrading guest accounts where it would fail if you opened the
registration email on a different device (PR #547)
* Fix bug where unread count could be wrong (PR #568)
Changes in synapse v0.12.1-rc1 (2016-01-29)
===========================================
Features:
* Add unread notification counts in ``/sync`` (PR #456)
* Add support for inviting 3pids in ``/createRoom`` (PR #460)
* Add ability for guest accounts to upgrade (PR #462)
* Add ``/versions`` API (PR #468)
* Add ``event`` to ``/context`` API (PR #492)
* Add specific error code for invalid user names in ``/register`` (PR #499)
* Add support for push badge counts (PR #507)
* Add support for non-guest users to peek in rooms using ``/events`` (PR #510)
Changes:
* Change ``/sync`` so that guest users only get rooms they've joined (PR #469)
* Change to require unbanning before other membership changes (PR #501)
* Change default push rules to notify for all messages (PR #486)
* Change default push rules to not notify on membership changes (PR #514)
* Change default push rules in one to one rooms to only notify for events that
are messages (PR #529)
* Change ``/sync`` to reject requests with a ``from`` query param (PR #512)
* Change server manhole to use SSH rather than telnet (PR #473)
* Change server to require AS users to be registered before use (PR #487)
* Change server not to start when ASes are invalidly configured (PR #494)
* Change server to require ID and ``as_token`` to be unique for AS's (PR #496)
* Change maximum pagination limit to 1000 (PR #497)
Bug fixes:
* Fix bug where ``/sync`` didn't return when something under the leave key
changed (PR #461)
* Fix bug where we returned smaller rather than larger than requested
thumbnails when ``method=crop`` (PR #464)
* Fix thumbnails API to only return cropped thumbnails when asking for a
cropped thumbnail (PR #475)
* Fix bug where we occasionally still logged access tokens (PR #477)
* Fix bug where ``/events`` would always return immediately for guest users
(PR #480)
* Fix bug where ``/sync`` unexpectedly returned old left rooms (PR #481)
* Fix enabling and disabling push rules (PR #498)
* Fix bug where ``/register`` returned 500 when given unicode username
(PR #513)
Changes in synapse v0.12.0 (2016-01-04) Changes in synapse v0.12.0 (2016-01-04)
======================================= =======================================

View File

@ -125,6 +125,15 @@ Installing prerequisites on Mac OS X::
sudo easy_install pip sudo easy_install pip
sudo pip install virtualenv sudo pip install virtualenv
Installing prerequisites on Raspbian::
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
sudo pip install --upgrade pip
sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
@ -167,8 +176,7 @@ identify itself to other Home Servers, so don't lose or delete them. It would be
wise to back them up somewhere safe. If, for whatever reason, you do need to wise to back them up somewhere safe. If, for whatever reason, you do need to
change your Home Server's keys, you may find that other Home Servers have the change your Home Server's keys, you may find that other Home Servers have the
old key cached. If you update the signing key, you should change the name of the old key cached. If you update the signing key, you should change the name of the
key in the <server name>.signing.key file (the second word, which by default is key in the <server name>.signing.key file (the second word) to something different.
, 'auto') to something different.
By default, registration of new users is disabled. You can either enable By default, registration of new users is disabled. You can either enable
registration in the config by specifying ``enable_registration: true`` registration in the config by specifying ``enable_registration: true``
@ -259,6 +267,14 @@ During setup of Synapse you need to call python2.7 directly again::
...substituting your host and domain name as appropriate. ...substituting your host and domain name as appropriate.
FreeBSD
-------
Synapse can be installed via FreeBSD Ports or Packages:
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
- Packages: ``pkg install py27-matrix-synapse``
Windows Install Windows Install
--------------- ---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages: Synapse can be installed on Cygwin. It requires the following Cygwin packages:
@ -303,6 +319,18 @@ may need to manually upgrade it::
sudo pip install --upgrade pip sudo pip install --upgrade pip
Installing may fail with ``Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)``.
You can fix this by manually upgrading pip and virtualenv::
sudo pip install --upgrade virtualenv
You can next rerun ``virtualenv -p python2.7 synapse`` to update the virtual env.
Installing may fail during installing virtualenv with ``InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.``
You can fix this by manually installing ndg-httpsclient::
pip install --upgrade ndg-httpsclient
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``. Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
You can fix this by upgrading setuptools:: You can fix this by upgrading setuptools::

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

151
contrib/graph/graph3.py Normal file
View File

@ -0,0 +1,151 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pydot
import cgi
import simplejson as json
import datetime
import argparse
from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
def make_graph(file_name, room_id, file_prefix, limit):
print "Reading lines"
with open(file_name) as f:
lines = f.readlines()
print "Read lines"
events = [FrozenEvent(json.loads(line)) for line in lines]
print "Loaded events."
events.sort(key=lambda e: e.depth)
print "Sorted events"
if limit:
events = events[-int(limit):]
node_map = {}
graph = pydot.Dot(graph_name="Test")
for event in events:
t = datetime.datetime.fromtimestamp(
float(event.origin_server_ts) / 1000
).strftime('%Y-%m-%d %H:%M:%S,%f')
content = json.dumps(unfreeze(event.get_dict()["content"]), indent=4)
content = content.replace("\n", "<br/>\n")
print content
content = []
for key, value in unfreeze(event.get_dict()["content"]).items():
if value is None:
value = "<null>"
elif isinstance(value, basestring):
pass
else:
value = json.dumps(value)
content.append(
"<b>%s</b>: %s," % (
cgi.escape(key, quote=True).encode("ascii", 'xmlcharrefreplace'),
cgi.escape(value, quote=True).encode("ascii", 'xmlcharrefreplace'),
)
)
content = "<br/>\n".join(content)
print content
label = (
"<"
"<b>%(name)s </b><br/>"
"Type: <b>%(type)s </b><br/>"
"State key: <b>%(state_key)s </b><br/>"
"Content: <b>%(content)s </b><br/>"
"Time: <b>%(time)s </b><br/>"
"Depth: <b>%(depth)s </b><br/>"
">"
) % {
"name": event.event_id,
"type": event.type,
"state_key": event.get("state_key", None),
"content": content,
"time": t,
"depth": event.depth,
}
node = pydot.Node(
name=event.event_id,
label=label,
)
node_map[event.event_id] = node
graph.add_node(node)
print "Created Nodes"
for event in events:
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
except:
end_node = pydot.Node(
name=prev_id,
label="<<b>%s</b>>" % (prev_id,),
)
node_map[prev_id] = end_node
graph.add_node(end_node)
edge = pydot.Edge(node_map[event.event_id], end_node)
graph.add_edge(edge)
print "Created edges"
graph.write('%s.dot' % file_prefix, format='raw', prog='dot')
print "Created Dot"
graph.write_svg("%s.svg" % file_prefix, prog='dot')
print "Created svg"
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate a PDU graph for a given room by reading "
"from a file with line deliminated events. \n"
"Requires pydot."
)
parser.add_argument(
"-p", "--prefix", dest="prefix",
help="String to prefix output files with",
default="graph_output"
)
parser.add_argument(
"-l", "--limit",
help="Only retrieve the last N events.",
)
parser.add_argument('event_file')
parser.add_argument('room')
args = parser.parse_args()
make_graph(args.event_file, args.room, args.prefix, args.limit)

View File

@ -26,7 +26,7 @@ TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else else
(cd .sytest-base; git fetch) (cd .sytest-base; git fetch -p)
fi fi
rm -rf sytest rm -rf sytest
@ -52,7 +52,7 @@ RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES=$RUN_POSTGRES:$port RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF cat > localhost-$port/database.yaml << EOF
name: psycopg2 name: psycopg2
args: args:
@ -62,7 +62,7 @@ EOF
done done
# Run if both postgresql databases exist # Run if both postgresql databases exist
if test $RUN_POSTGRES = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL"; echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2 $TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \ ./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \

View File

@ -1,5 +1,5 @@
#!/usr/bin/perl -pi #!/usr/bin/perl -pi
# Copyright 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
$copyright = <<EOT; $copyright = <<EOT;
/* Copyright 2015 OpenMarket Ltd /* Copyright 2016 OpenMarket Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#!/usr/bin/perl -pi #!/usr/bin/perl -pi
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
$copyright = <<EOT; $copyright = <<EOT;
# Copyright 2015 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

24
scripts-dev/dump_macaroon.py Executable file
View File

@ -0,0 +1,24 @@
#!/usr/bin/env python2
import pymacaroons
import sys
if len(sys.argv) == 1:
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
sys.exit(1)
macaroon_string = sys.argv[1]
key = sys.argv[2] if len(sys.argv) > 2 else None
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
print macaroon.inspect()
print ""
verifier = pymacaroons.Verifier()
verifier.satisfy_general(lambda c: True)
try:
verifier.verify(macaroon, key)
print "Signature is correct"
except Exception as e:
print e.message

View File

@ -0,0 +1,62 @@
#! /usr/bin/python
import ast
import argparse
import os
import sys
import yaml
PATTERNS_V1 = []
PATTERNS_V2 = []
RESULT = {
"v1": PATTERNS_V1,
"v2": PATTERNS_V2,
}
class CallVisitor(ast.NodeVisitor):
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
name = node.func.id
else:
return
if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
PATTERNS_V2.append(node.args[0].s)
def find_patterns_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = CallVisitor()
visitor.visit(input_ast)
def find_patterns_in_file(filepath):
with open(filepath) as f:
find_patterns_in_code(f.read())
parser = argparse.ArgumentParser(description='Find url patterns.')
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
find_patterns_in_file(filepath)
PATTERNS_V1.sort()
PATTERNS_V2.sort()
yaml.dump(RESULT, sys.stdout, default_flow_style=False)

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -16,3 +16,4 @@ ignore =
[flake8] [flake8]
max-line-length = 90 max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2014 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.12.0" __version__ = "0.13.0"

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,8 +22,9 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -510,35 +511,14 @@ class Auth(object):
""" """
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
access_token = request.args["access_token"][0] user_id = yield self._get_appservice_user_id(request.args)
if user_id:
# Check for application service tokens with a user_id override
try:
app_service = yield self.store.get_app_service_by_token(
access_token
)
if not app_service:
raise KeyError
user_id = app_service.sender
if "user_id" in request.args:
user_id = request.args["user_id"][0]
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not user_id:
raise KeyError
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue(
Requester(UserID.from_string(user_id), "", False)
)
defer.returnValue((UserID.from_string(user_id), "", False)) access_token = request.args["access_token"][0]
return
except KeyError:
pass # normal users won't have the user_id query parameter set.
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -550,7 +530,8 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
self.store.insert_client_ip( preserve_context_over_fn(
self.store.insert_client_ip,
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -564,13 +545,40 @@ class Auth(object):
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id, is_guest,)) defer.returnValue(Requester(user, token_id, is_guest))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN errcode=Codes.MISSING_TOKEN
) )
@defer.inlineCallbacks
def _get_appservice_user_id(self, request_args):
app_service = yield self.store.get_app_service_by_token(
request_args["access_token"][0]
)
if app_service is None:
defer.returnValue(None)
if "user_id" not in request_args:
defer.returnValue(app_service.sender)
user_id = request_args["user_id"][0]
if app_service.sender == user_id:
defer.returnValue(app_service.sender)
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not (yield self.store.get_user_by_id(user_id)):
raise AuthError(
403,
"Application service has not registered this user"
)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_user_by_access_token(self, token): def _get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.
@ -583,7 +591,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: try:
ret = yield self._get_user_from_macaroon(token) ret = yield self.get_user_from_macaroon(token)
except AuthError: except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens # TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons. # have been re-issued as macaroons.
@ -591,7 +599,7 @@ class Auth(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_user_from_macaroon(self, macaroon_str): def get_user_from_macaroon(self, macaroon_str):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", False) self.validate_macaroon(macaroon, "access", False)
@ -690,6 +698,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token): def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
logger.warn("Unrecognised access token - not in store: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
@ -707,6 +716,7 @@ class Auth(object):
token = request.args["access_token"][0] token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.", "Unrecognised access token.",

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ class Codes(object):
USER_IN_USE = "M_USER_IN_USE" USER_IN_USE = "M_USER_IN_USE"
ROOM_IN_USE = "M_ROOM_IN_USE" ROOM_IN_USE = "M_ROOM_IN_USE"
BAD_PAGINATION = "M_BAD_PAGINATION" BAD_PAGINATION = "M_BAD_PAGINATION"
BAD_STATE = "M_BAD_STATE"
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN" MISSING_TOKEN = "M_MISSING_TOKEN"
@ -42,6 +43,7 @@ class Codes(object):
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE" THREEPID_IN_USE = "THREEPID_IN_USE"
INVALID_USERNAME = "M_INVALID_USERNAME"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):
@ -120,22 +122,6 @@ class AuthError(SynapseError):
super(AuthError, self).__init__(*args, **kwargs) super(AuthError, self).__init__(*args, **kwargs)
class GuestAccessError(AuthError):
"""An error raised when a there is a problem with a guest user accessing
a room"""
def __init__(self, rooms, *args, **kwargs):
self.rooms = rooms
super(GuestAccessError, self).__init__(*args, **kwargs)
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
rooms=self.rooms,
)
class EventSizeError(SynapseError): class EventSizeError(SynapseError):
"""An error raised when an event is too big.""" """An error raised when an event is too big."""

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
import ujson as json
class Filtering(object): class Filtering(object):
@ -28,14 +30,14 @@ class Filtering(object):
return result return result
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter) self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter) return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or # TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for # replace_user_filter at some point? There's no REST API specified for
# them however # them however
def _check_valid_filter(self, user_filter_json): def check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid. """Check if the provided filter is valid.
This inspects all definitions contained within the filter. This inspects all definitions contained within the filter.
@ -129,88 +131,80 @@ class Filtering(object):
class FilterCollection(object): class FilterCollection(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self._filter_json = filter_json
room_filter_json = self.filter_json.get("room", {}) room_filter_json = self._filter_json.get("room", {})
self.room_filter = Filter({ self._room_filter = Filter({
k: v for k, v in room_filter_json.items() k: v for k, v in room_filter_json.items()
if k in ("rooms", "not_rooms") if k in ("rooms", "not_rooms")
}) })
self.room_timeline_filter = Filter(room_filter_json.get("timeline", {})) self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self.room_state_filter = Filter(room_filter_json.get("state", {})) self._room_state_filter = Filter(room_filter_json.get("state", {}))
self.room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self.room_account_data = Filter(room_filter_json.get("account_data", {})) self._room_account_data = Filter(room_filter_json.get("account_data", {}))
self.presence_filter = Filter(self.filter_json.get("presence", {})) self._presence_filter = Filter(filter_json.get("presence", {}))
self.account_data = Filter(self.filter_json.get("account_data", {})) self._account_data = Filter(filter_json.get("account_data", {}))
self.include_leave = self.filter_json.get("room", {}).get( self.include_leave = filter_json.get("room", {}).get(
"include_leave", False "include_leave", False
) )
def list_rooms(self): def __repr__(self):
return self.room_filter.list_rooms() return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
def get_filter_json(self):
return self._filter_json
def timeline_limit(self): def timeline_limit(self):
return self.room_timeline_filter.limit() return self._room_timeline_filter.limit()
def presence_limit(self): def presence_limit(self):
return self.presence_filter.limit() return self._presence_filter.limit()
def ephemeral_limit(self): def ephemeral_limit(self):
return self.room_ephemeral_filter.limit() return self._room_ephemeral_filter.limit()
def filter_presence(self, events): def filter_presence(self, events):
return self.presence_filter.filter(events) return self._presence_filter.filter(events)
def filter_account_data(self, events): def filter_account_data(self, events):
return self.account_data.filter(events) return self._account_data.filter(events)
def filter_room_state(self, events): def filter_room_state(self, events):
return self.room_state_filter.filter(self.room_filter.filter(events)) return self._room_state_filter.filter(self._room_filter.filter(events))
def filter_room_timeline(self, events): def filter_room_timeline(self, events):
return self.room_timeline_filter.filter(self.room_filter.filter(events)) return self._room_timeline_filter.filter(self._room_filter.filter(events))
def filter_room_ephemeral(self, events): def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(self.room_filter.filter(events)) return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
def filter_room_account_data(self, events): def filter_room_account_data(self, events):
return self.room_account_data.filter(self.room_filter.filter(events)) return self._room_account_data.filter(self._room_filter.filter(events))
class Filter(object): class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self.filter_json = filter_json
def list_rooms(self):
"""The list of room_id strings this filter restricts the output to
or None if the this filter doesn't list the room ids.
"""
if "rooms" in self.filter_json:
return list(set(self.filter_json["rooms"]))
else:
return None
def check(self, event): def check(self, event):
"""Checks whether the filter matches the given event. """Checks whether the filter matches the given event.
Returns: Returns:
bool: True if the event matches bool: True if the event matches
""" """
if isinstance(event, dict): sender = event.get("sender", None)
if not sender:
# Presence events have their 'sender' in content.user_id
sender = event.get("content", {}).get("user_id", None)
return self.check_fields( return self.check_fields(
event.get("room_id", None), event.get("room_id", None),
event.get("sender", None), sender,
event.get("type", None), event.get("type", None),
) )
else:
return self.check_fields(
getattr(event, "room_id", None),
getattr(event, "sender", None),
event.type,
)
def check_fields(self, room_id, sender, event_type): def check_fields(self, room_id, sender, event_type):
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.
@ -270,3 +264,6 @@ def _matches_wildcard(actual_value, filter_value):
return actual_value.startswith(type_prefix) return actual_value.startswith(type_prefix)
else: else:
return actual_value == filter_value return actual_value == filter_value
DEFAULT_FILTER_COLLECTION = FilterCollection({})

View File

@ -1,4 +1,4 @@
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,5 +23,6 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1" APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,3 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,61 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
from synapse.rest import ClientRestResource
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError
)
if __name__ == '__main__':
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.server import HomeServer
from twisted.internet import reactor, task, defer
from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse import events
from daemonize import Daemonize
import twisted.manhole.telnet
import synapse import synapse
import contextlib import contextlib
@ -77,29 +22,65 @@ import os
import re import re
import resource import resource
import subprocess import subprocess
import sys
import time import time
from synapse.config._base import ConfigError
from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS
)
from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.server import HomeServer
from twisted.conch.manhole import ColoredManhole
from twisted.conch.insults import insults
from twisted.conch import manhole_ssh
from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer
from twisted.application import service
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse import events
from daemonize import Daemonize
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
def gz_wrap(r): def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer): def build_resource_for_web_client(hs):
webclient_path = hs.get_config().web_client_location
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_client_resource(self):
return ClientRestResource(self)
def build_resource_for_federation(self):
return JsonResource(self)
def build_resource_for_web_client(self):
webclient_path = self.get_config().web_client_location
if not webclient_path: if not webclient_path:
try: try:
import syweb import syweb
@ -127,40 +108,8 @@ class SynapseHomeServer(HomeServer):
# return GzipFile(webclient_path) # TODO configurable? # return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
def build_resource_for_content_repo(self):
return ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
return MediaRepositoryResource(self)
def build_resource_for_server_key(self):
return LocalKey(self)
def build_resource_for_server_key_v2(self):
return KeyApiV2Resource(self)
def build_resource_for_metrics(self):
if self.get_config().enable_metrics:
return MetricsResource(self)
else:
return None
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config): def _listener_http(self, config, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_address = listener_config.get("bind_address", "")
@ -170,13 +119,11 @@ class SynapseHomeServer(HomeServer):
if tls and config.no_tls: if tls and config.no_tls:
return return
metrics_resource = self.get_resource_for_metrics()
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
for name in res["names"]: for name in res["names"]:
if name == "client": if name == "client":
client_resource = self.get_client_resource() client_resource = ClientRestResource(self)
if res["compress"]: if res["compress"]:
client_resource = gz_wrap(client_resource) client_resource = gz_wrap(client_resource)
@ -185,35 +132,42 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/r0": client_resource, "/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource, "/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource, "/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
}) })
if name == "federation": if name == "federation":
resources.update({ resources.update({
FEDERATION_PREFIX: self.get_resource_for_federation(), FEDERATION_PREFIX: TransportLayerServer(self),
}) })
if name in ["static", "client"]: if name in ["static", "client"]:
resources.update({ resources.update({
STATIC_PREFIX: self.get_resource_for_static_content(), STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
}) })
if name in ["media", "federation", "client"]: if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self)
resources.update({ resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(), MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(), LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
),
}) })
if name in ["keys", "federation"]: if name in ["keys", "federation"]:
resources.update({ resources.update({
SERVER_KEY_PREFIX: self.get_resource_for_server_key(), SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(), SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
}) })
if name == "webclient": if name == "webclient":
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client() resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and metrics_resource: if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = metrics_resource resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources) root_resource = create_resource_tree(resources)
if tls: if tls:
@ -248,10 +202,21 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listener_http(config, listener) self._listener_http(config, listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
f = twisted.manhole.telnet.ShellFactory() checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
f.username = "matrix" matrix="rabbithole"
f.password = "rabbithole" )
f.namespace['hs'] = self
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
ColoredManhole,
{
"__name__": "__console__",
"hs": self,
}
)
f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
f, f,
@ -276,6 +241,18 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn)
return db_conn
def quit_with_error(error_string): def quit_with_error(error_string):
message_lines = error_string.split("\n") message_lines = error_string.split("\n")
@ -358,10 +335,13 @@ def change_resource_limit(soft_file_no):
soft_file_no = hard soft_file_no = hard
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard)) resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
logger.info("Set file limit to: %d", soft_file_no) logger.info("Set file limit to: %d", soft_file_no)
resource.setrlimit(
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
)
except (ValueError, resource.error) as e: except (ValueError, resource.error) as e:
logger.warn("Failed to set file limit: %s", e) logger.warn("Failed to set file or core limit: %s", e)
def setup(config_options): def setup(config_options):
@ -373,11 +353,20 @@ def setup(config_options):
Returns: Returns:
HomeServer HomeServer
""" """
try:
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
config_options, config_options,
generate_section="Homeserver" generate_section="Homeserver"
) )
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
# If a config isn't returned, and an exception isn't raised, we're just
# generating config files and shouldn't try to continue.
sys.exit(0)
config.setup_logging() config.setup_logging()
@ -409,13 +398,7 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name']) logger.info("Preparing database: %s...", config.database_config['name'])
try: try:
db_conn = database_engine.module.connect( db_conn = hs.get_db_conn()
**{
k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn) database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine) hs.run_startup_checks(db_conn, database_engine)
@ -430,14 +413,18 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name']) logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
hs.start_listening() hs.start_listening()
def start():
hs.get_pusherpool().start() hs.get_pusherpool().start()
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
reactor.callWhenRunning(start)
return hs return hs
@ -475,9 +462,8 @@ class SynapseRequest(Request):
) )
def get_redacted_uri(self): def get_redacted_uri(self):
return re.sub( return ACCESS_TOKEN_RE.sub(
r'(\?.*access_token=)[^&]*(.*)$', r'\1<redacted>\3',
r'\1<redacted>\2',
self.uri self.uri
) )
@ -653,7 +639,7 @@ def _resource_id(resource, path_seg):
the mapping should looks like _resource_id(A,C) = B. the mapping should looks like _resource_id(A,C) = B.
Args: Args:
resource (Resource): The *parent* Resource resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached. path_seg (str): The name of the child Resource to be attached.
Returns: Returns:
str: A unique string which can be a key to the child Resource. str: A unique string which can be a key to the child Resource.
@ -688,6 +674,7 @@ def run(hs):
@defer.inlineCallbacks @defer.inlineCallbacks
def phone_stats_home(): def phone_stats_home():
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time()) now = int(hs.get_clock().time())
uptime = int(now - start_time) uptime = int(now - start_time)
if uptime < 0: if uptime < 0:
@ -699,8 +686,8 @@ def run(hs):
stats["uptime_seconds"] = uptime stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users() stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False) room_count = yield hs.get_datastore().get_room_count()
stats["total_room_count"] = len(all_rooms) stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages() daily_messages = yield hs.get_datastore().count_daily_messages()
@ -718,9 +705,12 @@ def run(hs):
if hs.config.report_stats: if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home) phone_home_task = task.LoopingCall(phone_stats_home)
logger.info("Scheduling stats reporting for 24 hour intervals")
phone_home_task.start(60 * 60 * 24, now=False) phone_home_task.start(60 * 60 * 24, now=False)
def in_thread(): def in_thread():
# Uncomment to enable tracing of log context changes.
# sys.settrace(logcontext_tracer)
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
reactor.run() reactor.run()

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.config._base import ConfigError
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
@ -21,7 +22,11 @@ if __name__ == "__main__":
if action == "read": if action == "read":
key = sys.argv[2] key = sys.argv[2]
try:
config = HomeServerConfig.load_config("", sys.argv[3:]) config = HomeServerConfig.load_config("", sys.argv[3:])
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
print getattr(config, key) print getattr(config, key)
sys.exit(0) sys.exit(0)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,7 +17,6 @@ import argparse
import errno import errno
import os import os
import yaml import yaml
import sys
from textwrap import dedent from textwrap import dedent
@ -136,13 +135,20 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs)) results.append(getattr(cls, name)(self, *args, **kargs))
return results return results
def generate_config(self, config_dir_path, server_name, report_stats=None): def generate_config(
self,
config_dir_path,
server_name,
is_generating_file,
report_stats=None,
):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", "default_config",
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
is_generating_file=is_generating_file,
report_stats=report_stats, report_stats=report_stats,
)) ))
@ -244,8 +250,10 @@ class Config(object):
server_name = config_args.server_name server_name = config_args.server_name
if not server_name: if not server_name:
print "Must specify a server_name to a generate config for." raise ConfigError(
sys.exit(1) "Must specify a server_name to a generate config for."
" Pass -H server.name."
)
if not os.path.exists(config_dir_path): if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
@ -253,6 +261,7 @@ class Config(object):
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
report_stats=(config_args.report_stats == "yes"), report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
@ -266,7 +275,7 @@ class Config(object):
"If this server name is incorrect, you will need to" "If this server name is incorrect, you will need to"
" regenerate the SSL certificates" " regenerate the SSL certificates"
) )
sys.exit(0) return
else: else:
print ( print (
"Config file %r already exists. Generating any missing key" "Config file %r already exists. Generating any missing key"
@ -302,25 +311,25 @@ class Config(object):
specified_config.update(yaml_config) specified_config.update(yaml_config)
if "server_name" not in specified_config: if "server_name" not in specified_config:
sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n") raise ConfigError(MISSING_SERVER_NAME)
sys.exit(1)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config( _, config = obj.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name server_name=server_name,
is_generating_file=False,
) )
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config: if "report_stats" not in config:
sys.stderr.write( raise ConfigError(
"\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
MISSING_REPORT_STATS_SPIEL + "\n") MISSING_REPORT_STATS_SPIEL
sys.exit(1) )
if generate_keys: if generate_keys:
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
sys.exit(0) return
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)

View File

@ -1,4 +1,4 @@
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,10 +29,10 @@ class CaptchaConfig(Config):
## Captcha ## ## Captcha ##
# This Home Server's ReCAPTCHA public key. # This Home Server's ReCAPTCHA public key.
recaptcha_private_key: "YOUR_PRIVATE_KEY" recaptcha_public_key: "YOUR_PUBLIC_KEY"
# This Home Server's ReCAPTCHA private key. # This Home Server's ReCAPTCHA private key.
recaptcha_public_key: "YOUR_PUBLIC_KEY" recaptcha_private_key: "YOUR_PRIVATE_KEY"
# Enables ReCaptcha checks when registering, preventing signup # Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha # unless a captcha is answered. Requires a valid ReCaptcha

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,8 +22,14 @@ from signedjson.key import (
read_signing_keys, write_signing_keys, NACL_ED25519 read_signing_keys, write_signing_keys, NACL_ED25519
) )
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from synapse.util.stringutils import random_string_with_symbols
import os import os
import hashlib
import logging
logger = logging.getLogger(__name__)
class KeyConfig(Config): class KeyConfig(Config):
@ -40,9 +46,29 @@ class KeyConfig(Config):
config["perspectives"] config["perspectives"]
) )
def default_config(self, config_dir_path, server_name, **kwargs): self.macaroon_secret_key = config.get(
"macaroon_secret_key", self.registration_shared_secret
)
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
logger.warn("Config is missing missing macaroon_secret_key")
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
else:
macaroon_secret_key = None
return """\ return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
## Signing Keys ## ## Signing Keys ##
# Path to the signing key to sign messages with # Path to the signing key to sign messages with

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,22 +23,23 @@ from distutils.util import strtobool
class RegistrationConfig(Config): class RegistrationConfig(Config):
def read_config(self, config): def read_config(self, config):
self.disable_registration = not bool( self.enable_registration = bool(
strtobool(str(config["enable_registration"])) strtobool(str(config["enable_registration"]))
) )
if "disable_registration" in config: if "disable_registration" in config:
self.disable_registration = bool( self.enable_registration = not bool(
strtobool(str(config["disable_registration"])) strtobool(str(config["disable_registration"]))
) )
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.allow_guest_access = config.get("allow_guest_access", False) self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50)
return """\ return """\
## Registration ## ## Registration ##
@ -49,8 +50,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s"
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.
@ -60,6 +59,12 @@ class RegistrationConfig(Config):
# participate in rooms hosted on this server which have been made # participate in rooms hosted on this server which have been made
# accessible to anonymous users. # accessible to anonymous users.
allow_guest_access: False allow_guest_access: False
# The list of identity servers trusted to verify third party
# identifiers by this server.
trusted_third_party_id_servers:
- matrix.org
- vector.im
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):
@ -71,6 +76,6 @@ class RegistrationConfig(Config):
def read_arguments(self, args): def read_arguments(self, args):
if args.enable_registration is not None: if args.enable_registration is not None:
self.disable_registration = not bool( self.enable_registration = bool(
strtobool(str(args.enable_registration)) strtobool(str(args.enable_registration))
) )

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -200,7 +200,7 @@ class ServerConfig(Config):
- names: [federation] - names: [federation]
compress: false compress: false
# Turn on the twisted telnet manhole service on localhost on the given # Turn on the twisted ssh manhole service on localhost on the given
# port. # port.
# - port: 9000 # - port: 9000
# bind_address: 127.0.0.1 # bind_address: 127.0.0.1

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from twisted.internet import defer from twisted.internet import defer
@ -142,6 +146,8 @@ class Keyring(object):
for server_name, _ in server_and_json for server_name, _ in server_and_json
} }
with PreserveLoggingContext():
# We want to wait for any previous lookups to complete before # We want to wait for any previous lookups to complete before
# proceeding. # proceeding.
wait_on_deferred = self.wait_for_previous_lookups( wait_on_deferred = self.wait_for_previous_lookups(
@ -175,7 +181,8 @@ class Keyring(object):
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
handle_key_deferred( preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id], group_id_to_group[g_id],
deferreds[g_id], deferreds[g_id],
) )
@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads if server_name in self.key_downloads
] ]
if wait_on: if wait_on:
with PreserveLoggingContext():
yield defer.DeferredList(wait_on) yield defer.DeferredList(wait_on)
else: else:
break break
for server_name, deferred in server_to_deferred.items(): for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred) d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d self.key_downloads[server_name] = d
def rm(r, server_name): def rm(r, server_name):
@ -244,6 +252,7 @@ class Keyring(object):
for group in group_id_to_group.values(): for group in group_id_to_group.values():
for key_id in group.key_ids: for key_id in group.key_ids:
if key_id in merged_results[group.server_name]: if key_id in merged_results[group.server_name]:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback(( group_id_to_deferred[group.group_id].callback((
group.group_id, group.group_id,
group.server_name, group.server_name,
@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store_keys( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
from_server=server_name, from_server=server_name,
verify_keys=verify_keys, verify_keys=verify_keys,
@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_keys_json( preserve_fn(self.store.store_server_keys_json)(
server_name=server_name, server_name=server_name,
key_id=key_id, key_id=key_id,
from_server=server_name, from_server=server_name,
@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_verify_key( preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
) )
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -117,6 +117,15 @@ class EventBase(object):
def __set__(self, instance, value): def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,)) raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
return self._event_dict[field]
def __contains__(self, field):
return field in self._event_dict
def items(self):
return self._event_dict.items()
class FrozenEvent(EventBase): class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state self.current_state = current_state
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = []

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,15 +17,10 @@
""" """
from .replication import ReplicationLayer from .replication import ReplicationLayer
from .transport import TransportLayer from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver): def initialize_http_replication(homeserver):
transport = TransportLayer( transport = TransportLayerClient(homeserver)
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()
)
return ReplicationLayer(homeserver, transport) return ReplicationLayer(homeserver, transport)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -57,7 +57,7 @@ class FederationClient(FederationBase):
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
clock=self._clock, clock=self._clock,
max_len=1000, max_len=1000,
expiry_ms=120*1000, expiry_ms=120 * 1000,
reset_expiry_on_get=False, reset_expiry_on_get=False,
) )

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -126,10 +126,8 @@ class FederationServer(FederationBase):
results = [] results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try: try:
yield d yield self._handle_new_pdu(transaction.origin, pdu)
results.append({}) results.append({})
except FederationError as e: except FederationError as e:
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.transport_layer = transport_layer self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
self.federation_client = self self.federation_client = self

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -103,7 +103,6 @@ class TransactionQueue(object):
else: else:
return not destination.startswith("localhost") return not destination.startswith("localhost")
@defer.inlineCallbacks
def enqueue_pdu(self, pdu, destinations, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
@ -141,8 +140,6 @@ class TransactionQueue(object):
deferreds.append(deferred) deferreds.append(deferred)
yield defer.DeferredList(deferreds, consumeErrors=True)
# NO inlineCallbacks # NO inlineCallbacks
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
support HTTPS), however individual pairings of servers may decide to support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol. communicate over a different (albeit still reliable) protocol.
""" """
from .server import TransportLayerServer
from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates
transactions and other requests to/from HTTP.
Attributes:
server_name (str): Local home server host
server (synapse.http.server.HttpServer): the http server to
register listeners on
client (synapse.http.client.HttpClient): the http client used to
send requests
request_handler (TransportRequestHandler): The handler to fire when we
receive requests for data.
received_handler (TransportReceivedHandler): The handler to fire when
we receive data.
"""
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
server (synapse.protocol.http.HttpServer): the http server to
register listeners on
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name
self.server = server
self.client = client
self.request_handler = None
self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class TransportLayerClient(object): class TransportLayerClient(object):
"""Sends federation HTTP requests to other servers""" """Sends federation HTTP requests to other servers"""
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_http_client()
@log_function @log_function
def get_room_state(self, destination, room_id, event_id): def get_room_state(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the """ Requests all state for a given room from the given server at the

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function from synapse.http.server import JsonResource
from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
import logging import logging
@ -28,9 +29,41 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TransportLayerServer(object): class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
super(TransportLayerServer, self).__init__(hs)
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
)
self.register_servlets()
def register_servlets(self):
register_servlets(
self.hs,
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
)
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request): def authenticate_request(self, request):
@ -98,37 +131,9 @@ class TransportLayerServer(object):
defer.returnValue((origin, content)) defer.returnValue((origin, content))
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
Args:
handler (TransportReceivedHandler)
"""
FederationSendServlet(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server)
@log_function
def register_request_handler(self, handler):
""" Register a handler that will be fired when we get asked for data.
Args:
handler (TransportRequestHandler)
"""
for servletclass in SERVLET_CLASSES:
servletclass(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
class BaseFederationServlet(object): class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter): def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler self.handler = handler
self.authenticator = authenticator self.authenticator = authenticator
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/" PATH = "/send/([^/]*)/"
def __init__(self, handler, server_name, **kwargs): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, **kwargs) super(FederationSendServlet, self).__init__(
handler, server_name=server_name, **kwargs
)
self.server_name = server_name self.server_name = server_name
# This is when someone is trying to send us a bunch of data. # This is when someone is trying to send us a bunch of data.
@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet, FederationPullServlet,
FederationEventServlet, FederationEventServlet,
FederationStateServlet, FederationStateServlet,
@ -451,3 +459,13 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
) )
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in SERVLET_CLASSES:
servletclass(
handler=hs.get_replication_layer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias
from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
@ -52,22 +53,51 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False, def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
require_all_visible_for_guests=True): """ Returns dict of user_id -> list of events that user is allowed to
# Assumes that user has at some point joined the room if not is_guest. see.
"""
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
row["event_id"] for rows in forgotten for row in rows
)
def allowed(event, user_id, is_peeking):
state = event_id_to_state[event.event_id]
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
def allowed(event, membership, visibility):
if visibility == "world_readable": if visibility == "world_readable":
return True return True
if is_guest: if is_peeking:
return False return False
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
if event.type == EventTypes.RoomHistoryVisibility: if event.type == EventTypes.RoomHistoryVisibility:
return not is_guest return not is_peeking
if visibility == "shared": if visibility == "shared":
return True return True
@ -78,54 +108,30 @@ class BaseHandler(object):
return True return True
event_id_to_state = yield self.store.get_state_for_events( defer.returnValue({
frozenset(e.event_id for e in events), user_id: [
types=( event
for event in events
if allowed(event, user_id, is_peeking)
]
for user_id, is_peeking in user_tuples
})
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False):
# Assumes that user has at some point joined the room if not is_guest.
types = (
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id), (EventTypes.Member, user_id),
) )
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
) )
res = yield self._filter_events_for_clients(
events_to_return = [] [(user_id, is_peeking)], events, event_id_to_state
for event in events:
state = event_id_to_state[event.event_id]
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
was_forgotten_at_event = yield self.store.was_forgotten_at(
membership_event.state_key,
membership_event.room_id,
membership_event.event_id
) )
if was_forgotten_at_event: defer.returnValue(res.get(user_id, []))
membership = None
else:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
if (require_all_visible_for_guests
and is_guest
and len(events_to_return) < len(events)):
# This indicates that some events in the requested range were not
# visible to guest users. To be safe, we reject the entire request,
# so that we don't have to worry about interpreting visibility
# boundaries.
raise AuthError(403, "User %s does not have permission" % (
user_id
))
defer.returnValue(events_to_return)
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()
@ -136,7 +142,7 @@ class BaseHandler(object):
) )
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(
retry_after_ms=int(1000*(time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -182,11 +188,9 @@ class BaseHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[], def handle_new_client_event(self, event, context, extra_users=[]):
extra_users=[], suppress_auth=False):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context.current_state.values())
@ -260,11 +264,16 @@ class BaseHandler(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context, self
)
(event_stream_id, max_stream_id) = yield self.store.persist_event( (event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context event, context=context
) )
destinations = set(extra_destinations) destinations = set()
for k, s in context.current_state.items(): for k, s in context.current_state.items():
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
@ -279,19 +288,11 @@ class BaseHandler(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners. # Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
notify_d.addErrback(log_failure)
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -408,7 +408,7 @@ class AuthHandler(BaseHandler):
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", True) auth_api.validate_macaroon(macaroon, "login", True)
return self._get_user_from_macaroon(macaroon) return self.get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
@ -421,7 +421,7 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
def _get_user_from_macaroon(self, macaroon): def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = " user_prefix = "user_id = "
for caveat in macaroon.caveats: for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix): if caveat.caveat_id.startswith(user_prefix):

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first. # If this server is in the list of servers, return it first.
if self.server_name in servers: if self.server_name in servers:
servers = ( servers = (
[self.server_name] [self.server_name] +
+ [s for s in servers if s != self.server_name] [s for s in servers if s != self.server_name]
) )
else: else:
servers = list(servers) servers = list(servers)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn
from ._base import BaseHandler from ._base import BaseHandler
@ -29,15 +30,17 @@ logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user): def started_user_eventstream(distributor, user):
return distributor.fire("started_user_eventstream", user) return preserve_context_over_fn(
distributor.fire,
"started_user_eventstream", user
)
def stopped_user_eventstream(distributor, user): def stopped_user_eventstream(distributor, user):
return distributor.fire("stopped_user_eventstream", user) return preserve_context_over_fn(
distributor.fire,
"stopped_user_eventstream", user
def user_joined_room(distributor, user, room_id): )
return distributor.fire("user_joined_room", user, room_id)
class EventStreamHandler(BaseHandler): class EventStreamHandler(BaseHandler):
@ -117,10 +120,10 @@ class EventStreamHandler(BaseHandler):
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True, as_client_event=True, affect_presence=True,
only_room_events=False, room_id=None, is_guest=False): only_keys=None, room_id=None, is_guest=False):
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned. If `only_keys` is not None, events from keys will be sent down.
""" """
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
@ -134,15 +137,12 @@ class EventStreamHandler(BaseHandler):
# Add some randomness to this value to try and mitigate against # Add some randomness to this value to try and mitigate against
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
if is_guest:
yield user_joined_room(self.distributor, auth_user, room_id)
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,
only_room_events=only_room_events, only_keys=only_keys,
is_guest=is_guest, guest_room_id=room_id is_guest=is_guest, explicit_room_id=room_id
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -36,6 +36,8 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.push.action_generator import ActionGenerator
from twisted.internet import defer from twisted.internet import defer
import itertools import itertools
@ -219,19 +221,11 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -635,19 +629,11 @@ class FederationHandler(BaseHandler):
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[joinee] extra_users=[joinee]
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
@ -722,18 +708,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
@ -803,19 +781,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[target_user], extra_users=[target_user],
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -940,18 +910,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event new_pdu = event
destinations = set() destinations = set()
@ -1105,6 +1067,12 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context, self
)
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
@ -1178,7 +1146,13 @@ class FederationHandler(BaseHandler):
try: try:
self.auth.check(e, auth_events=auth_for_e) self.auth.check(e, auth_events=auth_for_e)
except AuthError as err: except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
# rooms whose senders do not have the correct sigil; these
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
logger.warn( logger.warn(
"Rejecting %s because %s", "Rejecting %s because %s",
e.event_id, err.msg e.event_id, err.msg
@ -1684,7 +1658,7 @@ class FederationHandler(BaseHandler):
self.auth.check(event, context.current_state) self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state) yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context) yield member_handler.send_membership_event(event, context)
else: else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite( yield self.replication_layer.forward_third_party_invite(
@ -1713,7 +1687,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context) yield member_handler.send_membership_event(event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context): def add_display_name_to_third_party_invite(self, event_dict, event, context):

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -36,14 +36,15 @@ class IdentityHandler(BaseHandler):
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
self.trust_any_id_server_just_for_testing_do_not_use = (
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
@defer.inlineCallbacks @defer.inlineCallbacks
def threepid_from_creds(self, creds): def threepid_from_creds(self, creds):
yield run_on_reactor() yield run_on_reactor()
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds: if 'id_server' in creds:
id_server = creds['id_server'] id_server = creds['id_server']
elif 'idServer' in creds: elif 'idServer' in creds:
@ -58,7 +59,16 @@ class IdentityHandler(BaseHandler):
else: else:
raise SynapseError(400, "No client_secret in creds") raise SynapseError(400, "No client_secret in creds")
if id_server not in trustedIdServers: if id_server not in self.trusted_id_servers:
if self.trust_any_id_server_just_for_testing_do_not_use:
logger.warn(
"Trusting untrustworthy ID server %r even though it isn't"
" in the trusted id list for testing because"
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
" is set in the config",
id_server,
)
else:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' + logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server) 'credentials', id_server)
defer.returnValue(None) defer.returnValue(None)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError, AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -78,21 +78,20 @@ class MessageHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True, is_guest=False): as_client_event=True):
"""Get messages in a room. """Get messages in a room.
Args: Args:
user_id (str): The user requesting messages. requester (Requester): The user requesting messages.
room_id (str): The room they want messages from. room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
is_guest (bool): Whether the requesting user is a guest (as opposed
to a fully registered user).
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
user_id = requester.user.to_string()
data_source = self.hs.get_event_sources().sources["room"] data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token: if pagin_config.from_token:
@ -106,8 +105,6 @@ class MessageHandler(BaseHandler):
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token) room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace( pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token) "room_key", str(room_token)
@ -115,36 +112,37 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
if not is_guest: membership, member_event_id = yield self._check_in_room_or_world_readable(
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) room_id, user_id
if member_event.membership == Membership.LEAVE: )
if source_config.direction == 'b':
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
room_id, room_token.stream
)
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room. # they left the room, to save the effort of loading from the
# If they're a guest, we'll just 403 them if they're asking for # database.
# events they can't see.
leave_token = yield self.store.get_topological_token_for_event( leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological: if leave_token.topological < max_topo:
source_config.from_key = str(leave_token) source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological room_id, max_topo
) )
user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
user, source_config, room_id requester.user, source_config, room_id
) )
next_token = pagin_config.from_token.copy_and_replace( next_token = pagin_config.from_token.copy_and_replace(
@ -158,7 +156,11 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest) events = yield self._filter_events_for_client(
user_id,
events,
is_peeking=(member_event_id is None),
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -174,30 +176,25 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_event(self, event_dict, token_id=None, txn_id=None):
token_id=None, txn_id=None, is_guest=False): """
""" Given a dict from a client, create and handle a new event. Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events, Creates an FrozenEvent object, filling out auth_events, prev_events,
etc. etc.
Adds display names to Join membership events. Adds display names to Join membership events.
Persists and notifies local clients and federation.
Args: Args:
event_dict (dict): An entire event event_dict (dict): An entire event
Returns:
Tuple of created event (FrozenEvent), Context
""" """
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder) self.validator.validate_new(builder)
if ratelimit:
self.ratelimit(builder.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = UserID.from_string(builder.user_id)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -216,6 +213,25 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event( event, context = yield self._create_new_client_event(
builder=builder, builder=builder,
) )
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_event(self, event, context, ratelimit=True, is_guest=False):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if ratelimit:
self.ratelimit(event.sender)
if event.is_state(): if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -229,7 +245,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context, is_guest=is_guest) yield member_handler.send_membership_event(event, context, is_guest=is_guest)
else: else:
yield self.handle_new_client_event( yield self.handle_new_client_event(
event=event, event=event,
@ -241,6 +257,25 @@ class MessageHandler(BaseHandler):
with PreserveLoggingContext(): with PreserveLoggingContext():
presence.bump_presence_active_time(user) presence.bump_presence_active_time(user)
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None, is_guest=False):
"""
Creates an event, then sends it.
See self.create_event and self.send_event.
"""
event, context = yield self.create_event(
event_dict,
token_id=token_id,
txn_id=txn_id
)
yield self.send_event(
event,
context,
ratelimit=ratelimit,
is_guest=is_guest
)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -256,7 +291,7 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
membership, membership_event_id = yield self._check_in_room_or_world_readable( membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest room_id, user_id
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -273,7 +308,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id, is_guest): def _check_in_room_or_world_readable(self, room_id, user_id):
try: try:
# check_user_was_in_room will return the most recent membership # check_user_was_in_room will return the most recent membership
# event for the user if: # event for the user if:
@ -283,7 +318,7 @@ class MessageHandler(BaseHandler):
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id)) defer.returnValue((member_event.membership, member_event.event_id))
return return
except AuthError, auth_error: except AuthError:
visibility = yield self.state_handler.get_current_state( visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
@ -293,8 +328,6 @@ class MessageHandler(BaseHandler):
): ):
defer.returnValue((Membership.JOIN, None)) defer.returnValue((Membership.JOIN, None))
return return
if not is_guest:
raise auth_error
raise AuthError( raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
) )
@ -312,7 +345,7 @@ class MessageHandler(BaseHandler):
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
membership, membership_event_id = yield self._check_in_room_or_world_readable( membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest room_id, user_id
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -523,13 +556,13 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None, is_guest=False): def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of """Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left. the room this will be what was in the room when they left.
Args: Args:
user_id(str): The user to get a snapshot for. requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of. room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig): pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to The pagination config used to determine how many messages to
@ -540,19 +573,20 @@ class MessageHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room. A JSON serialisable dict with the snapshot of the room.
""" """
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable( membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, room_id, user_id,
user_id,
is_guest
) )
is_peeking = member_event_id is None
if membership == Membership.JOIN: if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined( result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_guest user_id, room_id, pagin_config, membership, is_peeking
) )
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted( result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_guest user_id, room_id, pagin_config, membership, member_event_id, is_peeking
) )
account_data_events = [] account_data_events = []
@ -576,7 +610,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config, def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_guest): membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event_id], None [member_event_id], None
) )
@ -598,7 +632,7 @@ class MessageHandler(BaseHandler):
) )
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages, is_guest=is_guest user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken(token[0], 0, 0, 0, 0)
@ -621,7 +655,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config, def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_guest): membership, is_peeking):
current_state = yield self.state.get_current_state( current_state = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
) )
@ -685,7 +719,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages, is_guest=is_guest, require_all_visible_for_guests=False user_id, messages, is_peeking=is_peeking,
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -704,7 +738,7 @@ class MessageHandler(BaseHandler):
"presence": presence, "presence": presence,
"receipts": receipts, "receipts": receipts,
} }
if not is_guest: if not is_peeking:
ret["membership"] = membership ret["membership"] = membership
defer.returnValue(ret) defer.returnValue(ret)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds # Don't bother bumping "last active" time if it differs by less than 60 seconds
LAST_ACTIVE_GRANULARITY = 60*1000 LAST_ACTIVE_GRANULARITY = 60 * 1000
# Keep no more than this number of offline serial revisions # Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000 MAX_OFFLINE_SERIALS = 1000
@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
was_polling = target_user in self._user_cachemap was_polling = target_user in self._user_cachemap
if now_online and not was_polling: if now_online and not was_polling:
self.start_polling_presence(target_user, state=state) yield self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling: elif not now_online and was_polling:
self.stop_polling_presence(target_user) yield self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
@ -394,6 +394,7 @@ class PresenceHandler(BaseHandler):
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return return
with PreserveLoggingContext():
self.changed_presencelike_data(user, {"last_active": now}) self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
@ -466,6 +467,7 @@ class PresenceHandler(BaseHandler):
local_user, room_ids=[room_id], add_to_cache=False local_user, room_ids=[room_id], add_to_cache=False
) )
with PreserveLoggingContext():
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=local_user, observed_user=local_user,
users_to_push=[user], users_to_push=[user],
@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
self.start_polling_presence( yield self.start_polling_presence(
observer_user, target_user=observed_user observer_user, target_user=observed_user
) )

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -21,7 +21,6 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
@ -40,19 +39,22 @@ class RegistrationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("registered_user") self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart): def check_username(self, localpart, guest_access_token=None):
yield run_on_reactor() yield run_on_reactor()
if urllib.quote(localpart) != localpart: if urllib.quote(localpart.encode('utf-8')) != localpart:
raise SynapseError( raise SynapseError(
400, 400,
"User ID must only contain characters which do not" "User ID can only contain characters a-z, 0-9, or '_-./'",
" require URL encoding." Codes.INVALID_USERNAME
) )
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
@ -62,19 +64,35 @@ class RegistrationHandler(BaseHandler):
users = yield self.store.get_users_by_id_case_insensitive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
if not guest_access_token:
raise SynapseError( raise SynapseError(
400, 400,
"User ID already taken.", "User ID already taken.",
errcode=Codes.USER_IN_USE, errcode=Codes.USER_IN_USE,
) )
user_data = yield self.auth.get_user_from_macaroon(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
"credentials for that user.",
errcode=Codes.FORBIDDEN,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart=None, password=None, generate_token=True): def register(
self,
localpart=None,
password=None,
generate_token=True,
guest_access_token=None,
make_guest=False
):
"""Registers a new client on the server. """Registers a new client on the server.
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be randomly generated. one will be generated.
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
@ -89,7 +107,19 @@ class RegistrationHandler(BaseHandler):
password_hash = self.auth_handler().hash(password) password_hash = self.auth_handler().hash(password)
if localpart: if localpart:
yield self.check_username(localpart) yield self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
if not was_guest:
try:
int(localpart)
raise RegistrationError(
400,
"Numeric user IDs are reserved for guest users."
)
except ValueError:
pass
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -100,37 +130,37 @@ class RegistrationHandler(BaseHandler):
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
else: else:
# autogen a random user ID # autogen a sequential user ID
attempts = 0 attempts = 0
user_id = None
token = None token = None
while not user_id: user = None
try: while not user:
localpart = self._generate_user_id() localpart = yield self._generate_user_id(attempts > 0)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
if generate_token: if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash) password_hash=password_hash,
make_guest=make_guest
yield registered_user(self.distributor, user) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user_id = None user_id = None
token = None token = None
attempts += 1 attempts += 1
if attempts > 5: yield registered_user(self.distributor, user)
raise RegistrationError(
500, "Cannot generate user ID.")
# We used to generate default identicons here, but nowadays # We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding # we want clients to generate their own as part of their branding
@ -156,7 +186,7 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash="" password_hash=""
) )
registered_user(self.distributor, user) yield registered_user(self.distributor, user)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -262,8 +292,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
def _generate_user_id(self): @defer.inlineCallbacks
return "-" + stringutils.random_string(18) def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
self._next_generated_user_id += 1
defer.returnValue(str(id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response): def _validate_captcha(self, ip_addr, private_key, challenge, response):

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,13 +18,14 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, Membership, JoinRules, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
return distributor.fire("user_left_room", user=user, room_id=room_id) return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id): def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user=user, room_id=room_id) return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
@ -115,6 +122,8 @@ class RoomCreationHandler(BaseHandler):
except: except:
raise SynapseError(400, "Invalid user_id: %s" % (i,)) raise SynapseError(400, "Invalid user_id: %s" % (i,))
invite_3pid_list = config.get("invite_3pid", [])
is_public = config.get("visibility", None) == "public" is_public = config.get("visibility", None) == "public"
if room_id: if room_id:
@ -220,6 +229,20 @@ class RoomCreationHandler(BaseHandler):
"content": {"membership": Membership.INVITE}, "content": {"membership": Membership.INVITE},
}, ratelimit=False) }, ratelimit=False)
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]
address = invite_3pid["address"]
medium = invite_3pid["medium"]
yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
room_id,
user,
medium,
address,
id_server,
token_id=None,
txn_id=None,
)
result = {"room_id": room_id} result = {"room_id": room_id}
if room_alias: if room_alias:
@ -381,7 +404,58 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain) remotedomains.add(member.domain)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True, is_guest=False): def update_membership(self, requester, target, room_id, action, txn_id=None):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
elif action == "forget":
effective_membership_state = "leave"
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": unicode(effective_membership_state)}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
yield msg_handler.send_event(
event,
context,
ratelimit=True,
is_guest=requester.is_guest
)
if action == "forget":
yield self.forget(requester.user, room_id)
@defer.inlineCallbacks
def send_membership_event(self, event, context, is_guest=False):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
Args: Args:
@ -416,7 +490,7 @@ class RoomMemberHandler(BaseHandler):
if not is_guest_access_allowed: if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
yield self._do_join(event, context, do_auth=do_auth) yield self._do_join(event, context)
else: else:
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context) is_host_in_room = yield self.is_host_in_room(room_id, context)
@ -443,9 +517,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"],
context=context, context=context,
do_auth=do_auth,
) )
if prev_state and prev_state.membership == Membership.JOIN: if prev_state and prev_state.membership == Membership.JOIN:
@ -481,12 +553,12 @@ class RoomMemberHandler(BaseHandler):
}) })
event, context = yield self._create_new_client_event(builder) event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_hosts=hosts, do_auth=True) yield self._do_join(event, context, room_hosts=hosts)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True): def _do_join(self, event, context, room_hosts=None):
room_id = event.room_id room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite # XXX: We don't do an auth check if we are doing an invite
@ -520,9 +592,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"],
context=context, context=context,
do_auth=do_auth,
) )
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -587,8 +657,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(room_ids) defer.returnValue(room_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_local_membership_update(self, event, membership, context, def _do_local_membership_update(self, event, context):
do_auth):
yield run_on_reactor() yield run_on_reactor()
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
@ -597,7 +666,6 @@ class RoomMemberHandler(BaseHandler):
event, event,
context, context,
extra_users=[target_user], extra_users=[target_user],
suppress_auth=(not do_auth),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -815,39 +883,71 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_public_room_list(self): def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True) room_ids = yield self.store.get_public_room_ids()
room_members = yield defer.gatherResults(
[
self.store.get_users_in_room(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
avatar_urls = yield defer.gatherResults(
[
self.get_room_avatar_url(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, room in enumerate(chunk):
room["num_joined_members"] = len(room_members[i])
if avatar_urls[i]:
room["avatar_url"] = avatar_urls[i]
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_avatar_url(self, room_id): def handle_room(room_id):
event = yield self.hs.get_state_handler().get_current_state( aliases = yield self.store.get_aliases_for_room(room_id)
room_id, "m.room.avatar" if not aliases:
defer.returnValue(None)
state = yield self.state_handler.get_current_state(room_id)
result = {"aliases": aliases, "room_id": room_id}
name_event = state.get((EventTypes.Name, ""), None)
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = state.get((EventTypes.Topic, ""), None)
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = state.get((EventTypes.CanonicalAlias, ""), None)
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = state.get((EventTypes.GuestAccess, ""), None)
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = state.get(("m.room.avatar", ""), None)
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
result["num_joined_members"] = sum(
1 for (event_type, _), ev in state.items()
if event_type == EventTypes.Member and ev.membership == Membership.JOIN
) )
if event and "url" in event.content:
defer.returnValue(event.content["url"]) defer.returnValue(result)
result = []
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
chunk_result = yield defer.gatherResults([
handle_room(room_id)
for room_id in chunk
], consumeErrors=True).addErrback(unwrapFirstError)
result.extend(v for v in chunk_result if v)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": result})
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@ -864,30 +964,39 @@ class RoomContextHandler(BaseHandler):
(excluding state). (excluding state).
Returns: Returns:
dict dict, or None if the event isn't found
""" """
before_limit = math.floor(limit/2.) before_limit = math.floor(limit / 2.)
after_limit = limit - before_limit after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
def filter_evts(events):
return self._filter_events_for_client(
user.to_string(),
events,
is_peeking=is_guest)
event = yield self.store.get_event(event_id, get_prev_content=True,
allow_none=True)
if not event:
defer.returnValue(None)
return
filtered = yield(filter_evts([event]))
if not filtered:
raise AuthError(
403,
"You don't have permission to access that event."
)
results = yield self.store.get_events_around( results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit room_id, event_id, before_limit, after_limit
) )
results["events_before"] = yield self._filter_events_for_client( results["events_before"] = yield filter_evts(results["events_before"])
user.to_string(), results["events_after"] = yield filter_evts(results["events_after"])
results["events_before"], results["event"] = event
is_guest=is_guest,
require_all_visible_for_guests=False
)
results["events_after"] = yield self._filter_events_for_client(
user.to_string(),
results["events_after"],
is_guest=is_guest,
require_all_visible_for_guests=False
)
if results["events_after"]: if results["events_after"]:
last_event_id = results["events_after"][-1].event_id last_event_id = results["events_after"][-1].event_id
@ -927,6 +1036,11 @@ class RoomEventSource(object):
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
app_service = yield self.store.get_app_service_by_user_id( app_service = yield self.store.get_app_service_by_user_id(
user.to_string() user.to_string()
) )
@ -938,15 +1052,30 @@ class RoomEventSource(object):
limit=limit, limit=limit,
) )
else: else:
events, end_key = yield self.store.get_room_events_stream( room_events = yield self.store.get_membership_changes_for_user(
user_id=user.to_string(), user.to_string(), from_key, to_key
)
room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
limit=limit, limit=limit or 10,
room_ids=room_ids,
is_guest=is_guest,
) )
events = list(room_events)
events.extend(e for evs, _ in room_to_events.values() for e in evs)
events.sort(key=lambda e: e.internal_metadata.order)
if limit:
events[:] = events[:limit]
if events:
end_key = events[-1].internal_metadata.after
else:
end_key = to_key
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,22 +15,25 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.api.errors import GuestAccessError
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure
from twisted.internet import defer from twisted.internet import defer
import collections import collections
import logging import logging
import itertools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [ SyncConfig = collections.namedtuple("SyncConfig", [
"user", "user",
"filter_collection",
"is_guest", "is_guest",
"filter",
]) ])
@ -54,6 +57,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
"ephemeral", "ephemeral",
"account_data", "account_data",
"unread_notifications",
])): ])):
__slots__ = [] __slots__ = []
@ -66,10 +70,12 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
or self.state or self.state
or self.ephemeral or self.ephemeral
or self.account_data or self.account_data
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
) )
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
"room_id", # str "room_id", # str
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
@ -115,11 +121,9 @@ class SyncResult(collections.namedtuple("SyncResult", [
events. events.
""" """
return bool( return bool(
self.presence or self.joined or self.invited self.presence or self.joined or self.invited or self.archived
) )
GuestRoom = collections.namedtuple("GuestRoom", ("room_id", "membership"))
class SyncHandler(BaseHandler): class SyncHandler(BaseHandler):
@ -138,45 +142,32 @@ class SyncHandler(BaseHandler):
A Deferred SyncResult. A Deferred SyncResult.
""" """
if sync_config.is_guest: context = LoggingContext.current_context()
bad_rooms = [] if context:
for room_id in sync_config.filter.list_rooms(): if since_token is None:
world_readable = yield self._is_world_readable(room_id) context.tag = "initial_sync"
if not world_readable: elif full_state:
bad_rooms.append(room_id) context.tag = "full_state_sync"
else:
if bad_rooms: context.tag = "incremental_sync"
raise GuestAccessError(
bad_rooms, 403, "Guest access not allowed"
)
if timeout == 0 or since_token is None or full_state: if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
result = yield self.current_sync_for_user(sync_config, since_token, result = yield self.current_sync_for_user(
full_state=full_state) sync_config, since_token, full_state=full_state,
)
defer.returnValue(result) defer.returnValue(result)
else: else:
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user, timeout, current_sync_callback, sync_config.user.to_string(), timeout, current_sync_callback,
from_token=since_token from_token=since_token,
) )
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state(
room_id,
EventTypes.RoomHistoryVisibility
)
if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable")
else:
defer.returnValue(False)
def current_sync_for_user(self, sync_config, since_token=None, def current_sync_for_user(self, sync_config, since_token=None,
full_state=False): full_state=False):
"""Get the sync for client needed to match what the server has now. """Get the sync for client needed to match what the server has now.
@ -200,19 +191,22 @@ class SyncHandler(BaseHandler):
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
if sync_config.is_guest: now_token, ephemeral_by_room = yield self.ephemeral_by_room(
room_list = [ sync_config, now_token
GuestRoom(room_id, Membership.JOIN) )
for room_id in sync_config.filter.list_rooms()
]
account_data = {} presence_stream = self.event_sources.sources["presence"]
account_data_by_room = {} # TODO (mjark): This looks wrong, shouldn't we be getting the presence
tags_by_room = {} # UP to the present rather than after the present?
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user=sync_config.user,
pagination_config=pagination_config.get_source_config("presence"),
key=None
)
else:
membership_list = (Membership.INVITE, Membership.JOIN) membership_list = (Membership.INVITE, Membership.JOIN)
if sync_config.filter.include_leave: if sync_config.filter_collection.include_leave:
membership_list += (Membership.LEAVE, Membership.BAN) membership_list += (Membership.LEAVE, Membership.BAN)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
@ -230,31 +224,18 @@ class SyncHandler(BaseHandler):
sync_config.user.to_string() sync_config.user.to_string()
) )
presence_stream = self.event_sources.sources["presence"]
joined_room_ids = [
room.room_id for room in room_list
if room.membership == Membership.JOIN
]
presence, _ = yield presence_stream.get_new_events(
from_key=0,
user=sync_config.user,
room_ids=joined_room_ids,
is_guest=sync_config.is_guest,
)
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, joined_room_ids
)
joined = [] joined = []
invited = [] invited = []
archived = [] archived = []
deferreds = [] deferreds = []
for event in room_list:
room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
for room_list_chunk in room_list_chunks:
for event in room_list_chunk:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
room_sync_deferred = self.full_state_sync_for_joined_room( room_sync_deferred = preserve_fn(
self.full_state_sync_for_joined_room
)(
room_id=event.room_id, room_id=event.room_id,
sync_config=sync_config, sync_config=sync_config,
now_token=now_token, now_token=now_token,
@ -275,7 +256,9 @@ class SyncHandler(BaseHandler):
leave_token = now_token.copy_and_replace( leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,) "room_key", "s%d" % (event.stream_ordering,)
) )
room_sync_deferred = self.full_state_sync_for_archived_room( room_sync_deferred = preserve_fn(
self.full_state_sync_for_archived_room
)(
sync_config=sync_config, sync_config=sync_config,
room_id=event.room_id, room_id=event.room_id,
leave_event_id=event.event_id, leave_event_id=event.event_id,
@ -291,9 +274,17 @@ class SyncHandler(BaseHandler):
deferreds, consumeErrors=True deferreds, consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)
presence = sync_config.filter_collection.filter_presence(
presence
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -314,17 +305,18 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token room_id, sync_config, now_token, since_token=timeline_since_token
) )
current_state = yield self.get_state_at(room_id, now_token) room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config,
now_token=now_token,
since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
batch=batch,
full_state=True,
)
defer.returnValue(JoinedSyncResult( defer.returnValue(room_sync)
room_id=room_id,
timeline=batch,
state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []),
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
))
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
account_data_events = [] account_data_events = []
@ -356,13 +348,11 @@ class SyncHandler(BaseHandler):
return account_data_events return account_data_events
@defer.inlineCallbacks @defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, room_ids, def ephemeral_by_room(self, sync_config, now_token, since_token=None):
since_token=None):
"""Get the ephemeral events for each room the user is in """Get the ephemeral events for each room the user is in
Args: Args:
sync_config (SyncConfig): The flags, filters and user for the sync. sync_config (SyncConfig): The flags, filters and user for the sync.
now_token (StreamToken): Where the server is currently up to. now_token (StreamToken): Where the server is currently up to.
room_ids (list): List of room id strings to get data for.
since_token (StreamToken): Where the server was when the client since_token (StreamToken): Where the server was when the client
last synced. last synced.
Returns: Returns:
@ -371,15 +361,19 @@ class SyncHandler(BaseHandler):
typing events for that room. typing events for that room.
""" """
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=False, is_guest=sync_config.is_guest,
) )
now_token = now_token.copy_and_replace("typing_key", typing_key) now_token = now_token.copy_and_replace("typing_key", typing_key)
@ -400,10 +394,9 @@ class SyncHandler(BaseHandler):
receipts, receipt_key = yield receipt_source.get_new_events( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code is_guest=sync_config.is_guest,
is_guest=False,
) )
now_token = now_token.copy_and_replace("receipt_key", receipt_key) now_token = now_token.copy_and_replace("receipt_key", receipt_key)
@ -416,31 +409,20 @@ class SyncHandler(BaseHandler):
defer.returnValue((now_token, ephemeral_by_room)) defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token, tags_by_room, timeline_since_token, tags_by_room,
account_data_by_room): account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred ArchivedSyncResult.
""" """
batch = yield self.load_filtered_recents( return self.incremental_sync_for_archived_room(
room_id, sync_config, leave_token, since_token=timeline_since_token sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
account_data_by_room, full_state=True, leave_token=leave_token,
) )
leave_state = yield self.store.get_state_for_event(leave_event_id)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state,
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
))
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token): def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to """ Get the incremental delta needed to bring the client up to
@ -450,49 +432,23 @@ class SyncHandler(BaseHandler):
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
if sync_config.is_guest: rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = sync_config.filter.list_rooms()
tags_by_room = {}
account_data = {}
account_data_by_room = {}
else:
rooms = yield self.store.get_rooms_for_user(
sync_config.user.to_string()
)
room_ids = [room.room_id for room in rooms] room_ids = [room.room_id for room in rooms]
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
)
tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(),
since_token.account_data_key,
)
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
sync_config.user.to_string(),
since_token.account_data_key,
)
)
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, room_ids, since_token
)
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter_collection.presence_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id( app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string() sync_config.user.to_string()
@ -505,114 +461,138 @@ class SyncHandler(BaseHandler):
sync_config.user sync_config.user
) )
timeline_limit = sync_config.filter.timeline_limit() user_id = sync_config.user.to_string()
room_events, _ = yield self.store.get_room_events_stream( timeline_limit = sync_config.filter_collection.timeline_limit()
sync_config.user.to_string(),
tags_by_room = yield self.store.get_updated_tags(
user_id,
since_token.account_data_key,
)
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
user_id,
since_token.account_data_key,
)
)
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
archived = []
invited = []
for room_id, events in mem_change_events_by_room_id.items():
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
# We want to figure out if we joined the room at some point since
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join:
old_state = yield self.get_state_at(room_id, since_token)
old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
if room_id in joined_room_ids:
continue
if not non_joins:
continue
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if room_sync:
invited.append(room_sync)
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
leave_events = [
e for e in non_joins
if e.membership in (Membership.LEAVE, Membership.BAN)
]
if leave_events:
leave_event = leave_events[-1]
room_sync = yield self.incremental_sync_for_archived_room(
sync_config, room_id, leave_event.event_id, since_token,
tags_by_room, account_data_by_room,
full_state=room_id in newly_joined_rooms
)
if room_sync:
archived.append(room_sync)
# Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=joined_room_ids,
from_key=since_token.room_key, from_key=since_token.room_key,
to_key=now_token.room_key, to_key=now_token.room_key,
limit=timeline_limit + 1, limit=timeline_limit + 1,
room_ids=room_ids if sync_config.is_guest else (),
is_guest=sync_config.is_guest,
) )
joined = [] joined = []
archived = [] # We loop through all room ids, even if there are no new events, in case
if len(room_events) <= timeline_limit: # there are non room events taht we need to notify about.
# There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them.
logger.debug("Got %i events for incremental sync - not limited",
len(room_events))
invite_events = []
leave_events = []
events_by_room_id = {}
for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event)
if event.room_id not in joined_room_ids:
if (event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()):
if event.membership == Membership.INVITE:
invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
for room_id in joined_room_ids: for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) room_entry = room_to_events.get(room_id, None)
logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if recents: if room_entry:
prev_batch = now_token.copy_and_replace( events, start_key = room_entry
"room_key", recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace("room_key", start_key)
newly_joined_room = room_id in newly_joined_rooms
full_state = newly_joined_room
batch = yield self.load_filtered_recents(
room_id, sync_config, prev_batch_token,
since_token=since_token,
recents=events,
newly_joined_room=newly_joined_room,
) )
else: else:
prev_batch = now_token batch = TimelineBatch(
events=[],
just_joined = yield self.check_joined_room(sync_config, state) prev_batch=since_token,
if just_joined: limited=False,
logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=TimelineBatch(
events=recents,
prev_batch=prev_batch,
limited=limited,
),
state=state,
ephemeral=ephemeral_by_room.get(room_id, []),
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
) )
logger.debug("Result for room %s: %r", room_id, room_sync) full_state = False
if room_sync:
joined.append(room_sync)
else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string()
)
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id=room_id,
ephemeral_by_room, tags_by_room, account_data_by_room sync_config=sync_config,
since_token=since_token,
now_token=now_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
batch=batch,
full_state=full_state,
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
for leave_event in leave_events: account_data_for_user = sync_config.filter_collection.filter_account_data(
room_sync = yield self.incremental_sync_for_archived_room( self.account_data_for_user(account_data)
sync_config, leave_event, since_token, tags_by_room,
account_data_by_room
) )
archived.append(room_sync)
invited = [ presence = sync_config.filter_collection.filter_presence(
InvitedSyncResult(room_id=event.room_id, invite=event) presence
for event in invite_events )
]
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -621,39 +601,58 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None, recents=None, newly_joined_room=False):
""" """
:returns a Deferred TimelineBatch :returns a Deferred TimelineBatch
""" """
limited = True with Measure(self.clock, "load_filtered_recents"):
recents = []
filtering_factor = 2 filtering_factor = 2
timeline_limit = sync_config.filter.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100) load_limit = max(timeline_limit * filtering_factor, 10)
max_repeat = 3 # Only try a few times per room, otherwise max_repeat = 5 # Only try a few times per room, otherwise
room_key = now_token.room_key room_key = now_token.room_key
end_key = room_key end_key = room_key
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
else:
limited = False
if recents is not None:
recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
recents,
is_peeking=sync_config.is_guest,
)
else:
recents = []
since_key = None
if since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat: while limited and len(recents) < timeline_limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room( events, end_key = yield self.store.get_room_events_stream_for_room(
room_id, room_id,
limit=load_limit + 1, limit=load_limit + 1,
from_token=since_token.room_key if since_token else None, from_key=since_key,
end_token=end_key, to_key=end_key,
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
) )
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
is_guest=sync_config.is_guest, is_peeking=sync_config.is_guest,
require_all_visible_for_guests=False
) )
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents
if len(events) <= load_limit: if len(events) <= load_limit:
limited = False limited = False
break
max_repeat -= 1 max_repeat -= 1
if len(recents) > timeline_limit: if len(recents) > timeline_limit:
@ -666,107 +665,105 @@ class SyncHandler(BaseHandler):
) )
defer.returnValue(TimelineBatch( defer.returnValue(TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=limited events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config, def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room, ephemeral_by_room, tags_by_room,
account_data_by_room): account_data_by_room,
""" Get the incremental delta needed to bring the client up to date for batch, full_state=False):
the room. Gives the client the most recent events and the changes to
state.
Returns:
A Deferred JoinedSyncResult
"""
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed.
batch = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logging.debug("Recents %r", batch)
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state = yield self.compute_state_delta( state = yield self.compute_state_delta(
since_token=since_token, room_id, batch, sync_config, since_token, now_token,
previous_state=state_at_previous_sync, full_state=full_state
current_state=current_state,
) )
just_joined = yield self.check_joined_room(sync_config, state) account_data = self.account_data_for_room(
if just_joined: room_id, tags_by_room, account_data_by_room
state = yield self.get_state_at(room_id, now_token) )
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
unread_notifications = {}
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room unread_notifications=unread_notifications,
),
) )
logging.debug("Room sync: %r", room_sync) if room_sync:
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config
)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
since_token, tags_by_room, since_token, tags_by_room,
account_data_by_room): account_data_by_room, full_state,
leave_token=None):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
A Deferred ArchivedSyncResult A Deferred ArchivedSyncResult
""" """
if not leave_token:
stream_token = yield self.store.get_stream_token_for_event( stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id leave_event_id
) )
leave_token = since_token.copy_and_replace("room_key", stream_token) leave_token = since_token.copy_and_replace("room_key", stream_token)
if since_token and since_token.is_after(leave_token):
defer.returnValue(None)
batch = yield self.load_filtered_recents( batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token, room_id, sync_config, leave_token, since_token,
) )
logging.debug("Recents %r", batch) logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id
)
state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
since_token=since_token, room_id, batch, sync_config, since_token, leave_token,
previous_state=state_at_previous_sync, full_state=full_state
current_state=state_events_at_leave, )
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
) )
room_sync = ArchivedSyncResult( room_sync = ArchivedSyncResult(
room_id=leave_event.room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
account_data=self.account_data_for_room( account_data=account_data,
leave_event.room_id, tags_by_room, account_data_by_room
),
) )
logging.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync) defer.returnValue(room_sync)
@ -804,15 +801,19 @@ class SyncHandler(BaseHandler):
state = {} state = {}
defer.returnValue(state) defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state): @defer.inlineCallbacks
""" Works out the differnce in state between the current state and the def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
state the client got when it last performed a sync. full_state):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str since_token: the point we are comparing against :param str room_id
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the :param TimelineBatch batch: The timeline batch for the room that will
state to compare to be sent to the user.
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the :param sync_config
new state :param str since_token: Token of the end of the previous batch. May be None.
:param str now_token: Token of the end of the current batch.
:param bool full_state: Whether to force returning the full state.
:returns A new event dictionary :returns A new event dictionary
""" """
@ -821,12 +822,53 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event. # updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
state_delta = {} with Measure(self.clock, "compute_state_delta"):
for key, event in current_state.iteritems(): if full_state:
if (key not in previous_state or if batch:
previous_state[key].event_id != event.event_id): state = yield self.store.get_state_for_event(
state_delta[key] = event batch.events[0].event_id
return state_delta )
else:
state = yield self.get_state_at(
room_id, stream_position=now_token
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state,
previous={},
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
)
else:
state = {}
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
def check_joined_room(self, sync_config, state_delta): def check_joined_room(self, sync_config, state_delta):
""" """
@ -845,3 +887,68 @@ class SyncHandler(BaseHandler):
if join_event.content["membership"] == Membership.JOIN: if join_event.content["membership"] == Membership.JOIN:
return True return True
return False return False
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_type="m.read"
)
notifs = []
if last_unread_event_id:
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
defer.returnValue(notifs)
# There is no new information in this period, so your notification
# count is whatever it was last time.
defer.returnValue(None)
def _action_has_highlight(actions):
for action in actions:
try:
if action.get("set_tweak", None) == "highlight":
return action.get("value", True)
except AttributeError:
pass
return False
def _calculate_state(timeline_contains, timeline_start, previous):
"""Works out what state to include in a sync response.
Args:
timeline_contains (dict): state in the timeline
timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync)
Returns:
dict
"""
event_id_to_state = {
e.event_id: e
for e in itertools.chain(
timeline_contains.values(),
previous.values(),
timeline_start.values(),
)
}
tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return {
(e.type, e.state_key): e
for e in evs
}

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
from synapse.types import UserID from synapse.types import UserID
import logging import logging
@ -222,6 +223,7 @@ class TypingNotificationHandler(BaseHandler):
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock()
self._handler = None self._handler = None
self._room_member_handler = None self._room_member_handler = None
@ -247,6 +249,7 @@ class TypingNotificationEventSource(object):
} }
def get_new_events(self, from_key, room_ids, **kwargs): def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.handler()

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,7 +17,7 @@ from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.names import client, dns from twisted.names import client, dns
from twisted.names.error import DNSNameError from twisted.names.error import DNSNameError, DomainError
import collections import collections
import logging import logging
@ -27,6 +27,14 @@ import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SERVER_CACHE = {}
_Server = collections.namedtuple(
"_Server", "priority weight host port"
)
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None): timeout=None):
"""Construct an endpoint for the given matrix destination. """Construct an endpoint for the given matrix destination.
@ -73,10 +81,6 @@ class SRVClientEndpoint(object):
Implements twisted.internet.interfaces.IStreamClientEndpoint. Implements twisted.internet.interfaces.IStreamClientEndpoint.
""" """
_Server = collections.namedtuple(
"_Server", "priority weight host port"
)
def __init__(self, reactor, service, domain, protocol="tcp", def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=TCP4ClientEndpoint, default_port=None, endpoint=TCP4ClientEndpoint,
endpoint_kw_args={}): endpoint_kw_args={}):
@ -84,7 +88,7 @@ class SRVClientEndpoint(object):
self.service_name = "_%s._%s.%s" % (service, protocol, domain) self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None: if default_port is not None:
self.default_server = self._Server( self.default_server = _Server(
host=domain, host=domain,
port=default_port, port=default_port,
priority=0, priority=0,
@ -101,32 +105,8 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_servers(self): def fetch_servers(self):
try:
answers, auth, add = yield client.lookupService(self.service_name)
except DNSNameError:
answers = []
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name('.')):
raise ConnectError("Service %s unavailable", self.service_name)
self.servers = []
self.used_servers = [] self.used_servers = []
self.servers = yield resolve_service(self.service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
self.servers.append(self._Server(
host=str(payload.target),
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
))
self.servers.sort()
def pick_server(self): def pick_server(self):
if not self.servers: if not self.servers:
@ -170,3 +150,64 @@ class SRVClientEndpoint(object):
) )
connection = yield endpoint.connect(protocolFactory) connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection) defer.returnValue(connection)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
servers = []
try:
try:
answers, _, _ = yield dns_client.lookupService(service_name)
except DNSNameError:
defer.returnValue([])
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name('.')):
raise ConnectError("Service %s unavailable", service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
host = str(payload.target)
try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
ips = [
answer.payload.dottedQuad()
for answer in answers
if answer.type == dns.A and answer.payload
]
for ip in ips:
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
))
servers.sort()
cache[service_name] = list(servers)
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
servers = list(cache_entry)
else:
raise e
defer.returnValue(servers)

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
return self.clock.time_bound_deferred( return self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=timeout/1000. if timeout else 60, time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn( response = yield preserve_context_over_fn(

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter( incoming_requests_counter = metrics.register_counter(
"requests", "requests",
labels=["method", "servlet"], labels=["method", "servlet", "tag"],
) )
outgoing_responses_counter = metrics.register_counter( outgoing_responses_counter = metrics.register_counter(
"responses", "responses",
@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
response_timer = metrics.register_distribution( response_timer = metrics.register_distribution(
"response_time", "response_time",
labels=["method", "servlet"] labels=["method", "servlet", "tag"]
) )
response_ru_utime = metrics.register_distribution( response_ru_utime = metrics.register_distribution(
"response_ru_utime", labels=["method", "servlet"] "response_ru_utime", labels=["method", "servlet", "tag"]
) )
response_ru_stime = metrics.register_distribution( response_ru_stime = metrics.register_distribution(
"response_ru_stime", labels=["method", "servlet"] "response_ru_stime", labels=["method", "servlet", "tag"]
) )
response_db_txn_count = metrics.register_distribution( response_db_txn_count = metrics.register_distribution(
"response_db_txn_count", labels=["method", "servlet"] "response_db_txn_count", labels=["method", "servlet", "tag"]
) )
response_db_txn_duration = metrics.register_distribution( response_db_txn_duration = metrics.register_distribution(
"response_db_txn_duration", labels=["method", "servlet"] "response_db_txn_duration", labels=["method", "servlet", "tag"]
) )
@ -99,9 +99,8 @@ def request_handler(request_handler):
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
try: try:
d = request_handler(self, request) with PreserveLoggingContext(request_context):
with PreserveLoggingContext(): yield request_handler(self, request)
yield d
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
if request.method == "OPTIONS": if request.method == "OPTIONS":
self._send_response(request, 200, {}) self._send_response(request, 200, {})
return return
start_context = LoggingContext.current_context()
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__ servlet_classname = servlet_instance.__class__.__name__
else: else:
servlet_classname = "%r" % callback servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
args = [ args = [
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups() urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
)
try: try:
context = LoggingContext.current_context() context = LoggingContext.current_context()
tag = ""
if context:
tag = context.tag
if context != start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
self.clock.time_msec() - start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage() ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(ru_utime, request.method, servlet_classname) response_ru_utime.inc_by(
response_ru_stime.inc_by(ru_stime, request.method, servlet_classname) ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by( response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname context.db_txn_count, request.method, servlet_classname, tag
) )
response_db_txn_duration.inc_by( response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname context.db_txn_duration, request.method, servlet_classname, tag
) )
except: except:
pass pass

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,10 +18,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
from collections import namedtuple
import logging import logging
@ -63,14 +66,15 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, current_token, time_now_ms, def __init__(self, user_id, rooms, current_token, time_now_ms,
appservice=None): appservice=None):
self.user = str(user) self.user_id = user_id
self.appservice = appservice self.appservice = appservice
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key, stream_id, time_now_ms):
@ -86,6 +90,8 @@ class _NotifierUserStream(object):
) )
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred noify_deferred = self.notify_deferred
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token) noify_deferred.callback(self.current_token)
@ -98,7 +104,7 @@ class _NotifierUserStream(object):
lst = notifier.room_to_user_streams.get(room, set()) lst = notifier.room_to_user_streams.get(room, set())
lst.discard(self) lst.discard(self)
notifier.user_to_user_stream.pop(self.user) notifier.user_to_user_stream.pop(self.user_id)
if self.appservice: if self.appservice:
notifier.appservice_to_user_streams.get( notifier.appservice_to_user_streams.get(
@ -118,6 +124,11 @@ class _NotifierUserStream(object):
return _NotificationListener(self.notify_deferred.observe()) return _NotificationListener(self.notify_deferred.observe())
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
new events available for it. new events available for it.
@ -177,8 +188,6 @@ class Notifier(object):
lambda: count(bool, self.appservice_to_user_streams.values()), lambda: count(bool, self.appservice_to_user_streams.values()),
) )
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
@ -192,8 +201,7 @@ class Notifier(object):
until all previous events have been persisted before notifying until all previous events have been persisted before notifying
the client streams. the client streams.
""" """
yield run_on_reactor() with PreserveLoggingContext():
self.pending_new_room_events.append(( self.pending_new_room_events.append((
room_stream_id, event, extra_users room_stream_id, event, extra_users
)) ))
@ -244,15 +252,13 @@ class Notifier(object):
extra_streams=app_streams, extra_streams=app_streams,
) )
@defer.inlineCallbacks
@log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[], def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()): extra_streams=set()):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor() with PreserveLoggingContext():
user_streams = set() user_streams = set()
for user in users: for user in users:
@ -271,21 +277,20 @@ class Notifier(object):
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
user = str(user) user_stream = self.user_to_user_stream.get(user_id)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user) appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
if room_ids is None: if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user) rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [room.room_id for room in rooms] room_ids = [room.room_id for room in rooms]
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user=user, user_id=user_id,
rooms=room_ids, rooms=room_ids,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
@ -302,7 +307,7 @@ class Notifier(object):
def timed_out(): def timed_out():
if listener: if listener:
listener.deferred.cancel() listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out) timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token prev_token = from_token
while not result: while not result:
@ -319,6 +324,7 @@ class Notifier(object):
# that we don't miss any current_token updates. # that we don't miss any current_token updates.
prev_token = current_token prev_token = current_token
listener = user_stream.new_listener(prev_token) listener = user_stream.new_listener(prev_token)
with PreserveLoggingContext():
yield listener.deferred yield listener.deferred
except defer.CancelledError: except defer.CancelledError:
break break
@ -332,13 +338,18 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_for(self, user, pagination_config, timeout, def get_events_for(self, user, pagination_config, timeout,
only_room_events=False, only_keys=None,
is_guest=False, guest_room_id=None): is_guest=False, explicit_room_id=None):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
If `only_room_events` is `True` only room events will be returned. If `only_keys` is not None, events from keys will be sent down.
If explicit_room_id is not set, the user's joined rooms will be polled
for events.
If explicit_room_id is set, that room will be polled for events only if
it is world readable or the user has joined the room.
""" """
from_token = pagination_config.from_token from_token = pagination_config.from_token
if not from_token: if not from_token:
@ -346,20 +357,13 @@ class Notifier(object):
limit = pagination_config.limit limit = pagination_config.limit
room_ids = [] room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
if is_guest: is_peeking = not is_joined
if guest_room_id:
if not (yield self._is_world_readable(guest_room_id)):
raise AuthError(403, "Guest access not allowed")
room_ids = [guest_room_id]
else:
rooms = yield self.store.get_rooms_for_user(user.to_string())
room_ids = [room.room_id for room in rooms]
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
defer.returnValue(None) defer.returnValue(EventStreamResult([], (from_token, from_token)))
events = [] events = []
end_token = from_token end_token = from_token
@ -370,13 +374,14 @@ class Notifier(object):
after_id = getattr(after_token, keyname) after_id = getattr(after_token, keyname)
if before_id == after_id: if before_id == after_id:
continue continue
if only_room_events and name != "room": if only_keys and name not in only_keys:
continue continue
new_events, new_key = yield source.get_new_events( new_events, new_key = yield source.get_new_events(
user=user, user=user,
from_key=getattr(from_token, keyname), from_key=getattr(from_token, keyname),
limit=limit, limit=limit,
is_guest=is_guest, is_guest=is_peeking,
room_ids=room_ids, room_ids=room_ids,
) )
@ -385,27 +390,50 @@ class Notifier(object):
new_events = yield room_member_handler._filter_events_for_client( new_events = yield room_member_handler._filter_events_for_client(
user.to_string(), user.to_string(),
new_events, new_events,
is_guest=is_guest, is_peeking=is_peeking,
require_all_visible_for_guests=False
) )
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
if events: defer.returnValue(EventStreamResult(events, (from_token, end_token)))
defer.returnValue((events, (from_token, end_token)))
else:
defer.returnValue(None)
result = yield self.wait_for_events( user_id_for_stream = user.to_string()
user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token if is_peeking:
# Internally, the notifier keeps an event stream per user_id.
# This is used by both /sync and /events.
# We want /events to be used for peeking independently of /sync,
# without polluting its contents. So we invent an illegal user ID
# (which thus cannot clash with any real users) for keying peeking
# over /events.
#
# I am sorry for what I have done.
user_id_for_stream = "_PEEKING_%s_%s" % (
explicit_room_id, user_id_for_stream
) )
if result is None: result = yield self.wait_for_events(
result = ([], (from_token, from_token)) user_id_for_stream,
timeout,
check_for_updates,
room_ids=room_ids,
from_token=from_token,
)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
joined_room_ids = map(lambda r: r.room_id, joined_rooms)
if explicit_room_id:
if explicit_room_id in joined_room_ids:
defer.returnValue(([explicit_room_id], True))
if (yield self._is_world_readable(explicit_room_id)):
defer.returnValue(([explicit_room_id], False))
raise AuthError(403, "Non-joined access not allowed")
defer.returnValue((joined_room_ids, True))
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_world_readable(self, room_id): def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state( state = yield self.hs.get_state_handler().get_current_state(
@ -433,7 +461,7 @@ class Notifier(object):
@log_function @log_function
def _register_with_keys(self, user_stream): def _register_with_keys(self, user_stream):
self.user_to_user_stream[user_stream.user] = user_stream self.user_to_user_stream[user_stream.user_id] = user_stream
for room in user_stream.rooms: for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set()) s = self.room_to_user_streams.setdefault(room, set())

Some files were not shown because too many files have changed in this diff Show More