Merge branch 'release-v0.18.5' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-12-16 10:40:10 +00:00
commit f5a4001bb1
105 changed files with 2917 additions and 1213 deletions

17
.travis.yml Normal file
View File

@ -0,0 +1,17 @@
sudo: false
language: python
python: 2.7
# tell travis to cache ~/.cache/pip
cache: pip
env:
- TOX_ENV=packaging
- TOX_ENV=pep8
- TOX_ENV=py27
install:
- pip install tox
script:
- tox -e $TOX_ENV

View File

@ -1,3 +1,68 @@
Changes in synapse v0.18.5 (2016-12-16)
=======================================
Bug fixes:
* Fix federation /backfill returning events it shouldn't (PR #1700)
* Fix crash in url preview (PR #1701)
Changes in synapse v0.18.5-rc3 (2016-12-13)
===========================================
Features:
* Add support for E2E for guests (PR #1653)
* Add new API appservice specific public room list (PR #1676)
* Add new room membership APIs (PR #1680)
Changes:
* Enable guest access for private rooms by default (PR #653)
* Limit the number of events that can be created on a given room concurrently
(PR #1620)
* Log the args that we have on UI auth completion (PR #1649)
* Stop generating refresh_tokens (PR #1654)
* Stop putting a time caveat on access tokens (PR #1656)
* Remove unspecced GET endpoints for e2e keys (PR #1694)
Bug fixes:
* Fix handling of 500 and 429's over federation (PR #1650)
* Fix Content-Type header parsing (PR #1660)
* Fix error when previewing sites that include unicode, thanks to kyrias (PR
#1664)
* Fix some cases where we drop read receipts (PR #1678)
* Fix bug where calls to ``/sync`` didn't correctly timeout (PR #1683)
* Fix bug where E2E key query would fail if a single remote host failed (PR
#1686)
Changes in synapse v0.18.5-rc2 (2016-11-24)
===========================================
Bug fixes:
* Don't send old events over federation, fixes bug in -rc1.
Changes in synapse v0.18.5-rc1 (2016-11-24)
===========================================
Features:
* Implement "event_fields" in filters (PR #1638)
Changes:
* Use external ldap auth pacakge (PR #1628)
* Split out federation transaction sending to a worker (PR #1635)
* Fail with a coherent error message if `/sync?filter=` is invalid (PR #1636)
* More efficient notif count queries (PR #1644)
Changes in synapse v0.18.4 (2016-11-22) Changes in synapse v0.18.4 (2016-11-22)
======================================= =======================================

View File

@ -120,6 +120,7 @@ Installing prerequisites on Mac OS X::
xcode-select --install xcode-select --install
sudo easy_install pip sudo easy_install pip
sudo pip install virtualenv sudo pip install virtualenv
brew install pkg-config libffi
Installing prerequisites on Raspbian:: Installing prerequisites on Raspbian::
@ -136,6 +137,10 @@ Installing prerequisites on openSUSE::
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \ sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
python-devel libffi-devel libopenssl-devel libjpeg62-devel python-devel libffi-devel libopenssl-devel libjpeg62-devel
Installing prerequisites on OpenBSD::
doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
libxslt
To install the synapse homeserver run:: To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
@ -369,6 +374,32 @@ Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Mo
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean`` - Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
- Packages: ``pkg install py27-matrix-synapse`` - Packages: ``pkg install py27-matrix-synapse``
OpenBSD
-------
There is currently no port for OpenBSD. Additionally, OpenBSD's security
settings require a slightly more difficult installation process.
1) Create a new directory in ``/usr/local`` called ``_synapse``. Also, create a
new user called ``_synapse`` and set that directory as the new user's home.
This is required because, by default, OpenBSD only allows binaries which need
write and execute permissions on the same memory space to be run from
``/usr/local``.
2) ``su`` to the new ``_synapse`` user and change to their home directory.
3) Create a new virtualenv: ``virtualenv -p python2.7 ~/.synapse``
4) Source the virtualenv configuration located at
``/usr/local/_synapse/.synapse/bin/activate``. This is done in ``ksh`` by
using the ``.`` command, rather than ``bash``'s ``source``.
5) Optionally, use ``pip`` to install ``lxml``, which Synapse needs to parse
webpages for their titles.
6) Use ``pip`` to install this repository: ``pip install
https://github.com/matrix-org/synapse/tarball/master``
7) Optionally, change ``_synapse``'s shell to ``/bin/false`` to reduce the
chance of a compromised Synapse server being used to take over your box.
After this, you may proceed with the rest of the install directions.
NixOS NixOS
----- -----

View File

@ -22,3 +22,4 @@ export SYNAPSE_CACHE_FACTOR=1
--federation-reader \ --federation-reader \
--client-reader \ --client-reader \
--appservice \ --appservice \
--federation-sender \

View File

@ -15,6 +15,6 @@ tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools $TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install { python synapse/python_dependencies.py
$TOX_BIN/pip install lxml echo lxml psycopg2
$TOX_BIN/pip install psycopg2 } | xargs $TOX_BIN/pip install

View File

@ -23,6 +23,45 @@ import sys
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
# Some notes on `setup.py test`:
#
# Once upon a time we used to try to make `setup.py test` run `tox` to run the
# tests. That's a bad idea for three reasons:
#
# 1: `setup.py test` is supposed to find out whether the tests work in the
# *current* environmentt, not whatever tox sets up.
# 2: Empirically, trying to install tox during the test run wasn't working ("No
# module named virtualenv").
# 3: The tox documentation advises against it[1].
#
# Even further back in time, we used to use setuptools_trial [2]. That has its
# own set of issues: for instance, it requires installation of Twisted to build
# an sdist (because the recommended mode of usage is to add it to
# `setup_requires`). That in turn means that in order to successfully run tox
# you have to have the python header files installed for whichever version of
# python tox uses (which is python3 on recent ubuntus, for example).
#
# So, for now at least, we stick with what appears to be the convention among
# Twisted projects, and don't attempt to do anything when someone runs
# `setup.py test`; instead we direct people to run `trial` directly if they
# care.
#
# [1]: http://tox.readthedocs.io/en/2.5.0/example/basic.html#integration-with-setup-py-test-command
# [2]: https://pypi.python.org/pypi/setuptools_trial
class TestCommand(Command):
user_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
print ("""Synapse's tests cannot be run via setup.py. To run them, try:
PYTHONPATH="." trial tests
""")
def read_file(path_segments): def read_file(path_segments):
"""Read a file from the package. Takes a list of strings to join to """Read a file from the package. Takes a list of strings to join to
make the path""" make the path"""
@ -39,38 +78,6 @@ def exec_file(path_segments):
return result return result
class Tox(Command):
user_options = [('tox-args=', 'a', "Arguments to pass to tox")]
def initialize_options(self):
self.tox_args = None
def finalize_options(self):
self.test_args = []
self.test_suite = True
def run(self):
#import here, cause outside the eggs aren't loaded
try:
import tox
except ImportError:
try:
self.distribution.fetch_build_eggs("tox")
import tox
except:
raise RuntimeError(
"The tests need 'tox' to run. Please install 'tox'."
)
import shlex
args = self.tox_args
if args:
args = shlex.split(self.tox_args)
else:
args = []
errno = tox.cmdline(args=args)
sys.exit(errno)
version = exec_file(("synapse", "__init__.py"))["__version__"] version = exec_file(("synapse", "__init__.py"))["__version__"]
dependencies = exec_file(("synapse", "python_dependencies.py")) dependencies = exec_file(("synapse", "python_dependencies.py"))
long_description = read_file(("README.rst",)) long_description = read_file(("README.rst",))
@ -86,5 +93,5 @@ setup(
zip_safe=False, zip_safe=False,
long_description=long_description, long_description=long_description,
scripts=["synctl"] + glob.glob("scripts/*"), scripts=["synctl"] + glob.glob("scripts/*"),
cmdclass={'test': Tox}, cmdclass={'test': TestCommand},
) )

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.18.4" __version__ = "0.18.5"

View File

@ -39,6 +39,9 @@ AuthEventTypes = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
) )
# guests always get this device id.
GUEST_DEVICE_ID = "guest_device"
class Auth(object): class Auth(object):
""" """
@ -51,17 +54,6 @@ class Auth(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"guest = ",
"type = ",
"time < ",
"user_id = ",
])
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True): def check_from_context(self, event, context, do_sig_check=True):
@ -685,31 +677,28 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"): def get_user_by_access_token(self, token, rights="access"):
""" Get a registered user's ID. """ Validate access token and get user_id from it
Args: Args:
token (str): The access token to get the user by. token (str): The access token to get the user by.
rights (str): The operation being performed; the access token must
allow this.
Returns: Returns:
dict : dict that includes the user and the ID of their access token. dict : dict that includes the user and the ID of their access token.
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: try:
ret = yield self.get_user_from_macaroon(token, rights) macaroon = pymacaroons.Macaroon.deserialize(token)
except AuthError: except Exception: # deserialize can throw more-or-less anything
# TODO(daniel): Remove this fallback when all existing access tokens # doesn't look like a macaroon: treat it as an opaque token which
# have been re-issued as macaroons. # must be in the database.
if self.hs.config.expire_access_token: # TODO: it would be nice to get rid of this, but apparently some
raise # people use access tokens which aren't macaroons
ret = yield self._look_up_user_by_access_token(token) r = yield self._look_up_user_by_access_token(token)
defer.returnValue(r)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_from_macaroon(self, macaroon_str, rights="access"):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
user_id = self.get_user_id_from_macaroon(macaroon) user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -724,11 +713,36 @@ class Auth(object):
guest = True guest = True
if guest: if guest:
# Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway).
#
# In order to prevent guest access tokens being used as regular
# user access tokens (and hence getting around the invalidation
# process), we look up the user id and check that it is indeed
# a guest user.
#
# It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing
# guest tokens.
stored_user = yield self.store.get_user_by_id(user_id)
if not stored_user:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unknown user_id %s" % user_id,
errcode=Codes.UNKNOWN_TOKEN
)
if not stored_user["is_guest"]:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Guest access token used for regular user",
errcode=Codes.UNKNOWN_TOKEN
)
ret = { ret = {
"user": user, "user": user,
"is_guest": True, "is_guest": True,
"token_id": None, "token_id": None,
"device_id": None, # all guests get the same device id
"device_id": GUEST_DEVICE_ID,
} }
elif rights == "delete_pusher": elif rights == "delete_pusher":
# We don't store these tokens in the database # We don't store these tokens in the database
@ -750,7 +764,7 @@ class Auth(object):
# macaroon. They probably should be. # macaroon. They probably should be.
# TODO: build the dictionary from the macaroon once the # TODO: build the dictionary from the macaroon once the
# above are fixed # above are fixed
ret = yield self._look_up_user_by_access_token(macaroon_str) ret = yield self._look_up_user_by_access_token(token)
if ret["user"] != user: if ret["user"] != user:
logger.error( logger.error(
"Macaroon user (%s) != DB user (%s)", "Macaroon user (%s) != DB user (%s)",
@ -798,27 +812,38 @@ class Auth(object):
Args: Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access", "refresh", type_string(str): The kind of token required (e.g. "access",
"delete_pusher") "delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired. verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
user_id (str): The user_id required user_id (str): The user_id required
""" """
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
# the verifier runs a test for every caveat on the macaroon, to check
# that it is met for the current request. Each caveat must match at
# least one of the predicates specified by satisfy_exact or
# specify_general.
v.satisfy_exact("gen = 1") v.satisfy_exact("gen = 1")
v.satisfy_exact("type = " + type_string) v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true") v.satisfy_exact("guest = true")
# verify_expiry should really always be True, but there exist access
# tokens in the wild which expire when they should not, so we can't
# enforce expiry yet (so we have to allow any caveat starting with
# 'time < ' in access tokens).
#
# On the other hand, short-term login tokens (as used by CAS login, for
# example) have an expiry time which we do want to enforce.
if verify_expiry: if verify_expiry:
v.satisfy_general(self._verify_expiry) v.satisfy_general(self._verify_expiry)
else: else:
v.satisfy_general(lambda c: c.startswith("time < ")) v.satisfy_general(lambda c: c.startswith("time < "))
v.verify(macaroon, self.hs.config.macaroon_secret_key) # access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
def _verify_expiry(self, caveat): def _verify_expiry(self, caveat):
@ -829,15 +854,6 @@ class Auth(object):
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
def _verify_recognizes_caveats(self, caveat):
first_space = caveat.find(" ")
if first_space < 0:
return False
second_space = caveat.find(" ", first_space + 1)
if second_space < 0:
return False
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
@defer.inlineCallbacks @defer.inlineCallbacks
def _look_up_user_by_access_token(self, token): def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)

View File

@ -39,6 +39,7 @@ class Codes(object):
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM" MISSING_PARAM = "M_MISSING_PARAM"
INVALID_PARAM = "M_INVALID_PARAM"
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"

View File

@ -71,6 +71,21 @@ class Filtering(object):
if key in user_filter_json["room"]: if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key]) self._check_definition(user_filter_json["room"][key])
if "event_fields" in user_filter_json:
if type(user_filter_json["event_fields"]) != list:
raise SynapseError(400, "event_fields must be a list of strings")
for field in user_filter_json["event_fields"]:
if not isinstance(field, basestring):
raise SynapseError(400, "Event field must be a string")
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
if r'\\' in field:
raise SynapseError(
400, r'The escape character \ cannot itself be escaped'
)
def _check_definition_room_lists(self, definition): def _check_definition_room_lists(self, definition):
"""Check that "rooms" and "not_rooms" are lists of room ids if they """Check that "rooms" and "not_rooms" are lists of room ids if they
are present are present
@ -152,6 +167,7 @@ class FilterCollection(object):
self.include_leave = filter_json.get("room", {}).get( self.include_leave = filter_json.get("room", {}).get(
"include_leave", False "include_leave", False
) )
self.event_fields = filter_json.get("event_fields", [])
def __repr__(self): def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),) return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
@ -186,6 +202,26 @@ class FilterCollection(object):
def filter_room_account_data(self, events): def filter_room_account_data(self, events):
return self._room_account_data.filter(self._room_filter.filter(events)) return self._room_account_data.filter(self._room_filter.filter(events))
def blocks_all_presence(self):
return (
self._presence_filter.filters_all_types() or
self._presence_filter.filters_all_senders()
)
def blocks_all_room_ephemeral(self):
return (
self._room_ephemeral_filter.filters_all_types() or
self._room_ephemeral_filter.filters_all_senders() or
self._room_ephemeral_filter.filters_all_rooms()
)
def blocks_all_room_timeline(self):
return (
self._room_timeline_filter.filters_all_types() or
self._room_timeline_filter.filters_all_senders() or
self._room_timeline_filter.filters_all_rooms()
)
class Filter(object): class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):
@ -202,6 +238,15 @@ class Filter(object):
self.contains_url = self.filter_json.get("contains_url", None) self.contains_url = self.filter_json.get("contains_url", None)
def filters_all_types(self):
return "*" in self.not_types
def filters_all_senders(self):
return "*" in self.not_senders
def filters_all_rooms(self):
return "*" in self.not_rooms
def check(self, event): def check(self, event):
"""Checks whether the filter matches the given event. """Checks whether the filter matches the given event.

View File

@ -0,0 +1,331 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.http.site import SynapseSite
from synapse.federation import send_queue
from synapse.federation.units import Edu
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse import events
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
import ujson as json
logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore,
):
pass
class FederationSenderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse federation_sender now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
send_handler = FederationSenderHandler(self)
send_handler.on_start()
while True:
try:
args = store.stream_positions()
args.update((yield send_handler.stream_positions()))
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
yield send_handler.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse federation sender", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_sender"
setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
if config.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``send_federation: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.send_federation = True
tls_server_context_factory = context_factory.ServerContextFactory(config)
ps = FederationSenderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-federation-sender",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
class FederationSenderHandler(object):
"""Processes the replication stream and forwards the appropriate entries
to the federation sender.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation_sender = hs.get_federation_sender()
self._room_serials = {}
self._room_typing = {}
def on_start(self):
# There may be some events that are persisted but haven't been sent,
# so send them now.
self.federation_sender.notify_new_events(
self.store.get_room_max_stream_ordering()
)
@defer.inlineCallbacks
def stream_positions(self):
stream_id = yield self.store.get_federation_out_pos("federation")
defer.returnValue({
"federation": stream_id,
# Ack stuff we've "processed", this should only be called from
# one process.
"federation_ack": stream_id,
})
@defer.inlineCallbacks
def process_replication(self, result):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
fed_stream = result.get("federation")
if fed_stream:
latest_id = int(fed_stream["position"])
# The federation stream containis a bunch of different types of
# rows that need to be handled differently. We parse the rows, put
# them into the appropriate collection and then send them off.
presence_to_send = {}
keyed_edus = {}
edus = {}
failures = {}
device_destinations = set()
# Parse the rows in the stream
for row in fed_stream["rows"]:
position, typ, content_js = row
content = json.loads(content_js)
if typ == send_queue.PRESENCE_TYPE:
destination = content["destination"]
state = UserPresenceState.from_dict(content["state"])
presence_to_send.setdefault(destination, []).append(state)
elif typ == send_queue.KEYED_EDU_TYPE:
key = content["key"]
edu = Edu(**content["edu"])
keyed_edus.setdefault(
edu.destination, {}
)[(edu.destination, tuple(key))] = edu
elif typ == send_queue.EDU_TYPE:
edu = Edu(**content)
edus.setdefault(edu.destination, []).append(edu)
elif typ == send_queue.FAILURE_TYPE:
destination = content["destination"]
failure = content["failure"]
failures.setdefault(destination, []).append(failure)
elif typ == send_queue.DEVICE_MESSAGE_TYPE:
device_destinations.add(content["destination"])
else:
raise Exception("Unrecognised federation type: %r", typ)
# We've finished collecting, send everything off
for destination, states in presence_to_send.items():
self.federation_sender.send_presence(destination, states)
for destination, edu_map in keyed_edus.items():
for key, edu in edu_map.items():
self.federation_sender.send_edu(
edu.destination, edu.edu_type, edu.content, key=key,
)
for destination, edu_list in edus.items():
for edu in edu_list:
self.federation_sender.send_edu(
edu.destination, edu.edu_type, edu.content, key=None,
)
for destination, failure_list in failures.items():
for failure in failure_list:
self.federation_sender.send_failure(destination, failure)
for destination in device_destinations:
self.federation_sender.send_device_messages(destination)
# Record where we are in the stream.
yield self.store.update_federation_out_pos(
"federation", latest_id
)
# We also need to poke the federation sender when new events happen
event_stream = result.get("events")
if event_stream:
latest_pos = event_stream["position"]
self.federation_sender.notify_new_events(latest_pos)
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@ -89,6 +89,9 @@ class ApplicationService(object):
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
# .protocols is a publicly visible field # .protocols is a publicly visible field
if protocols: if protocols:
self.protocols = set(protocols) self.protocols = set(protocols)

View File

@ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
import logging import logging
import urllib import urllib
@ -177,6 +178,13 @@ class ApplicationServiceApi(SimpleHttpClient):
" valid result", uri) " valid result", uri)
defer.returnValue(None) defer.returnValue(None)
for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
if network_id is not None:
instance["instance_id"] = ThirdPartyInstanceID(
service.id, network_id,
).to_string()
defer.returnValue(info) defer.returnValue(info)
except Exception as ex: except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s", logger.warning("query_3pe_protocol to %s threw exception %s",

View File

@ -50,6 +50,7 @@ handlers:
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: precise
filters: [context]
loggers: loggers:
synapse: synapse:

View File

@ -27,17 +27,23 @@ class PasswordAuthProviderConfig(Config):
ldap_config = config.get("ldap_config", {}) ldap_config = config.get("ldap_config", {})
self.ldap_enabled = ldap_config.get("enabled", False) self.ldap_enabled = ldap_config.get("enabled", False)
if self.ldap_enabled: if self.ldap_enabled:
from synapse.util.ldap_auth_provider import LdapAuthProvider from ldap_auth_provider import LdapAuthProvider
parsed_config = LdapAuthProvider.parse_config(ldap_config) parsed_config = LdapAuthProvider.parse_config(ldap_config)
self.password_providers.append((LdapAuthProvider, parsed_config)) self.password_providers.append((LdapAuthProvider, parsed_config))
providers = config.get("password_providers", []) providers = config.get("password_providers", [])
for provider in providers: for provider in providers:
# We need to import the module, and then pick the class out of # This is for backwards compat when the ldap auth provider resided
# that, so we split based on the last dot. # in this package.
module, clz = provider['module'].rsplit(".", 1) if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
module = importlib.import_module(module) from ldap_auth_provider import LdapAuthProvider
provider_class = getattr(module, clz) provider_class = LdapAuthProvider
else:
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try: try:
provider_config = provider_class.parse_config(provider["config"]) provider_config = provider_class.parse_config(provider["config"])
@ -50,7 +56,7 @@ class PasswordAuthProviderConfig(Config):
def default_config(self, **kwargs): def default_config(self, **kwargs):
return """\ return """\
# password_providers: # password_providers:
# - module: "synapse.util.ldap_auth_provider.LdapAuthProvider" # - module: "ldap_auth_provider.LdapAuthProvider"
# config: # config:
# enabled: true # enabled: true
# uri: "ldap://ldap.example.com:389" # uri: "ldap://ldap.example.com:389"

View File

@ -32,7 +32,6 @@ class RegistrationConfig(Config):
) )
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.user_creation_max_duration = int(config["user_creation_max_duration"])
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@ -55,11 +54,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
# Sets the expiry for the short term user creation in
# milliseconds. For instance the bellow duration is two weeks
# in milliseconds.
user_creation_max_duration: 1209600000
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.

View File

@ -30,6 +30,11 @@ class ServerConfig(Config):
self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl") self.public_baseurl = config.get("public_baseurl")
# Whether to send federation traffic out in this process. This only
# applies to some federation traffic, and so shouldn't be used to
# "disable" federation
self.send_federation = config.get("send_federation", True)
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/': if self.public_baseurl[-1] != '/':
self.public_baseurl += '/' self.public_baseurl += '/'

View File

@ -16,6 +16,17 @@
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from . import EventBase from . import EventBase
from frozendict import frozendict
import re
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
def prune_event(event): def prune_event(event):
""" Returns a pruned version of the given event, which removes all keys we """ Returns a pruned version of the given event, which removes all keys we
@ -97,6 +108,83 @@ def prune_event(event):
) )
def _copy_field(src, dst, field):
"""Copy the field in 'src' to 'dst'.
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
then dst={"foo":{"bar":5}}.
Args:
src(dict): The dict to read from.
dst(dict): The dict to modify.
field(list<str>): List of keys to drill down to in 'src'.
"""
if len(field) == 0: # this should be impossible
return
if len(field) == 1: # common case e.g. 'origin_server_ts'
if field[0] in src:
dst[field[0]] = src[field[0]]
return
# Else is a nested field e.g. 'content.body'
# Pop the last field as that's the key to move across and we need the
# parent dict in order to access the data. Drill down to the right dict.
key_to_move = field.pop(-1)
sub_dict = src
for sub_field in field: # e.g. sub_field => "content"
if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
sub_dict = sub_dict[sub_field]
else:
return
if key_to_move not in sub_dict:
return
# Insert the key into the output dictionary, creating nested objects
# as required. We couldn't do this any earlier or else we'd need to delete
# the empty objects if the key didn't exist.
sub_out_dict = dst
for sub_field in field:
sub_out_dict = sub_out_dict.setdefault(sub_field, {})
sub_out_dict[key_to_move] = sub_dict[key_to_move]
def only_fields(dictionary, fields):
"""Return a new dict with only the fields in 'dictionary' which are present
in 'fields'.
If there are no event fields specified then all fields are included.
The entries may include '.' charaters to indicate sub-fields.
So ['content.body'] will include the 'body' field of the 'content' object.
A literal '.' character in a field name may be escaped using a '\'.
Args:
dictionary(dict): The dictionary to read from.
fields(list<str>): A list of fields to copy over. Only shallow refs are
taken.
Returns:
dict: A new dictionary with only the given fields. If fields was empty,
the same dictionary is returned.
"""
if len(fields) == 0:
return dictionary
# for each field, convert it:
# ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
[f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
]
output = {}
for field_array in split_fields:
_copy_field(dictionary, output, field_array)
return output
def format_event_raw(d): def format_event_raw(d):
return d return d
@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d):
def serialize_event(e, time_now_ms, as_client_event=True, def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1, event_format=format_event_for_client_v1,
token_id=None): token_id=None, only_event_fields=None):
# FIXME(erikj): To handle the case of presence events and the like # FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase): if not isinstance(e, EventBase):
return e return e
@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d["unsigned"]["transaction_id"] = txn_id d["unsigned"]["transaction_id"] = txn_id
if as_client_event: if as_client_event:
return event_format(d) d = event_format(d)
else:
return d if only_event_fields:
if (not isinstance(only_event_fields, list) or
not all(isinstance(f, basestring) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
return d

View File

@ -17,10 +17,9 @@
""" """
from .replication import ReplicationLayer from .replication import ReplicationLayer
from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver): def initialize_http_replication(hs):
transport = TransportLayerClient(homeserver) transport = hs.get_federation_transport_client()
return ReplicationLayer(homeserver, transport) return ReplicationLayer(hs, transport)

View File

@ -18,7 +18,6 @@ from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from synapse.api.constants import Membership from synapse.api.constants import Membership
from .units import Edu
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
@ -45,10 +44,6 @@ logger = logging.getLogger(__name__)
# synapse.federation.federation_client is a silly name # synapse.federation.federation_client is a silly name
metrics = synapse.metrics.get_metrics_for("synapse.federation.client") metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
@ -92,63 +87,6 @@ class FederationClient(FederationBase):
self._get_pdu_cache.start() self._get_pdu_cache.start()
@log_function
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
sent_pdus_destination_dist.inc_by(len(destinations))
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
def send_presence(self, destination, states):
if destination != self.server_name:
self._transaction_queue.enqueue_presence(destination, states)
@log_function
def send_edu(self, destination, edu_type, content, key=None):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
sent_edus_counter.inc()
self._transaction_queue.enqueue_edu(edu, key=key)
@log_function
def send_device_messages(self, destination):
"""Sends the device messages in the local database to the remote
destination"""
self._transaction_queue.enqueue_device_messages(destination)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function @log_function
def make_query(self, destination, query_type, args, def make_query(self, destination, query_type, args,
retry_on_dns_fail=False): retry_on_dns_fail=False):
@ -717,12 +655,15 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
def get_public_rooms(self, destination, limit=None, since_token=None, def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None):
if destination == self.server_name: if destination == self.server_name:
return return
return self.transport_layer.get_public_rooms( return self.transport_layer.get_public_rooms(
destination, limit, since_token, search_filter destination, limit, since_token, search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -20,8 +20,6 @@ a given transport.
from .federation_client import FederationClient from .federation_client import FederationClient
from .federation_server import FederationServer from .federation_server import FederationServer
from .transaction_queue import TransactionQueue
from .persistence import TransactionActions from .persistence import TransactionActions
import logging import logging
@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store) self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = TransactionQueue(hs, transport_layer)
self._order = 0
self.hs = hs self.hs = hs

View File

@ -0,0 +1,298 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A federation sender that forwards things to be sent across replication to
a worker process.
It assumes there is a single worker process feeding off of it.
Each row in the replication stream consists of a type and some json, where the
types indicate whether they are presence, or edus, etc.
Ephemeral or non-event data are queued up in-memory. When the worker requests
updates since a particular point, all in-memory data since before that point is
dropped. We also expire things in the queue after 5 minutes, to ensure that a
dead worker doesn't cause the queues to grow limitlessly.
Events are replicated via a separate events stream.
"""
from .units import Edu
from synapse.util.metrics import Measure
import synapse.metrics
from blist import sorteddict
import ujson
metrics = synapse.metrics.get_metrics_for(__name__)
PRESENCE_TYPE = "p"
KEYED_EDU_TYPE = "k"
EDU_TYPE = "e"
FAILURE_TYPE = "f"
DEVICE_MESSAGE_TYPE = "d"
class FederationRemoteSendQueue(object):
"""A drop in replacement for TransactionQueue"""
def __init__(self, hs):
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.presence_map = {}
self.presence_changed = sorteddict()
self.keyed_edu = {}
self.keyed_edu_changed = sorteddict()
self.edus = sorteddict()
self.failures = sorteddict()
self.device_messages = sorteddict()
self.pos = 1
self.pos_time = sorteddict()
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
metrics.register_callback(
queue_name + "_size",
lambda: len(queue),
)
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
"edus", "failures", "device_messages", "pos_time",
]:
register(queue_name, getattr(self, queue_name))
self.clock.looping_call(self._clear_queue, 30 * 1000)
def _next_pos(self):
pos = self.pos
self.pos += 1
self.pos_time[self.clock.time_msec()] = pos
return pos
def _clear_queue(self):
"""Clear the queues for anything older than N minutes"""
FIVE_MINUTES_AGO = 5 * 60 * 1000
now = self.clock.time_msec()
keys = self.pos_time.keys()
time = keys.bisect_left(now - FIVE_MINUTES_AGO)
if not keys[:time]:
return
position_to_delete = max(keys[:time])
for key in keys[:time]:
del self.pos_time[key]
self._clear_queue_before_pos(position_to_delete)
def _clear_queue_before_pos(self, position_to_delete):
"""Clear all the queues from before a given position"""
with Measure(self.clock, "send_queue._clear"):
# Delete things out of presence maps
keys = self.presence_changed.keys()
i = keys.bisect_left(position_to_delete)
for key in keys[:i]:
del self.presence_changed[key]
user_ids = set(
user_id for uids in self.presence_changed.values() for _, user_id in uids
)
to_del = [
user_id for user_id in self.presence_map if user_id not in user_ids
]
for user_id in to_del:
del self.presence_map[user_id]
# Delete things out of keyed edus
keys = self.keyed_edu_changed.keys()
i = keys.bisect_left(position_to_delete)
for key in keys[:i]:
del self.keyed_edu_changed[key]
live_keys = set()
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
for edu_key in to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
keys = self.edus.keys()
i = keys.bisect_left(position_to_delete)
for key in keys[:i]:
del self.edus[key]
# Delete things out of failure map
keys = self.failures.keys()
i = keys.bisect_left(position_to_delete)
for key in keys[:i]:
del self.failures[key]
# Delete things out of device map
keys = self.device_messages.keys()
i = keys.bisect_left(position_to_delete)
for key in keys[:i]:
del self.device_messages[key]
def notify_new_events(self, current_id):
"""As per TransactionQueue"""
# We don't need to replicate this as it gets sent down a different
# stream.
pass
def send_edu(self, destination, edu_type, content, key=None):
"""As per TransactionQueue"""
pos = self._next_pos()
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
if key:
assert isinstance(key, tuple)
self.keyed_edu[(destination, key)] = edu
self.keyed_edu_changed[pos] = (destination, key)
else:
self.edus[pos] = edu
def send_presence(self, destination, states):
"""As per TransactionQueue"""
pos = self._next_pos()
self.presence_map.update({
state.user_id: state
for state in states
})
self.presence_changed[pos] = [
(destination, state.user_id) for state in states
]
def send_failure(self, failure, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.failures[pos] = (destination, str(failure))
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.device_messages[pos] = destination
def get_current_token(self):
return self.pos - 1
def get_replication_rows(self, token, limit, federation_ack=None):
"""
Args:
token (int)
limit (int)
federation_ack (int): Optional. The position where the worker is
explicitly acknowledged it has handled. Allows us to drop
data from before that point
"""
# TODO: Handle limit.
# To handle restarts where we wrap around
if token > self.pos:
token = -1
rows = []
# There should be only one reader, so lets delete everything its
# acknowledged its seen.
if federation_ack:
self._clear_queue_before_pos(federation_ack)
# Fetch changed presence
keys = self.presence_changed.keys()
i = keys.bisect_right(token)
dest_user_ids = set(
(pos, dest_user_id)
for pos in keys[i:]
for dest_user_id in self.presence_changed[pos]
)
for (key, (dest, user_id)) in dest_user_ids:
rows.append((key, PRESENCE_TYPE, ujson.dumps({
"destination": dest,
"state": self.presence_map[user_id].as_dict(),
})))
# Fetch changes keyed edus
keys = self.keyed_edu_changed.keys()
i = keys.bisect_right(token)
keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
for (pos, (destination, edu_key)) in keyed_edus:
rows.append(
(pos, KEYED_EDU_TYPE, ujson.dumps({
"key": edu_key,
"edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
}))
)
# Fetch changed edus
keys = self.edus.keys()
i = keys.bisect_right(token)
edus = set((k, self.edus[k]) for k in keys[i:])
for (pos, edu) in edus:
rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
# Fetch changed failures
keys = self.failures.keys()
i = keys.bisect_right(token)
failures = set((k, self.failures[k]) for k in keys[i:])
for (pos, (destination, failure)) in failures:
rows.append((pos, FAILURE_TYPE, ujson.dumps({
"destination": destination,
"failure": failure,
})))
# Fetch changed device messages
keys = self.device_messages.keys()
i = keys.bisect_right(token)
device_messages = set((k, self.device_messages[k]) for k in keys[i:])
for (pos, destination) in device_messages:
rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
"destination": destination,
})))
# Sort rows based on pos
rows.sort()
return rows

View File

@ -19,6 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions from .persistence import TransactionActions
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
@ -26,6 +27,7 @@ from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination, get_retry_limiter, NotRetryingDestination,
) )
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
import synapse.metrics import synapse.metrics
@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__) metrics = synapse.metrics.get_metrics_for(__name__)
client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
sent_pdus_destination_dist = client_metrics.register_distribution(
"sent_pdu_destinations"
)
sent_edus_counter = client_metrics.register_counter("sent_edus")
class TransactionQueue(object): class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at """This class makes sure we only have one transaction in flight at
@ -44,13 +52,14 @@ class TransactionQueue(object):
It batches pending PDUs into single transactions. It batches pending PDUs into single transactions.
""" """
def __init__(self, hs, transport_layer): def __init__(self, hs):
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.transaction_actions = TransactionActions(self.store) self.transaction_actions = TransactionActions(self.store)
self.transport_layer = transport_layer self.transport_layer = hs.get_federation_transport_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -95,6 +104,11 @@ class TransactionQueue(object):
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
self._order = 1
self._is_processing = False
self._last_poked_id = -1
def can_send_to(self, destination): def can_send_to(self, destination):
"""Can we send messages to the given server? """Can we send messages to the given server?
@ -115,11 +129,61 @@ class TransactionQueue(object):
else: else:
return not destination.startswith("localhost") return not destination.startswith("localhost")
def enqueue_pdu(self, pdu, destinations, order): @defer.inlineCallbacks
def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to
send out to other servers.
"""
self._last_poked_id = max(current_id, self._last_poked_id)
if self._is_processing:
return
try:
self._is_processing = True
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
last_token, self._last_poked_id, limit=20,
)
logger.debug("Handling %s -> %s", last_token, next_token)
if not events and next_token >= self._last_poked_id:
break
for event in events:
users_in_room = yield self.state.get_current_user_in_room(
event.room_id, latest_event_ids=[event.event_id],
)
destinations = set(
get_domain_from_id(user_id) for user_id in users_in_room
)
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(event.state_key))
logger.debug("Sending %s to %r", event, destinations)
self._send_pdu(event, destinations)
yield self.store.update_federation_out_pos(
"events", next_token
)
finally:
self._is_processing = False
def _send_pdu(self, pdu, destinations):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later. # table and we'll get back to it later.
order = self._order
self._order += 1
destinations = set(destinations) destinations = set(destinations)
destinations = set( destinations = set(
dest for dest in destinations if self.can_send_to(dest) dest for dest in destinations if self.can_send_to(dest)
@ -130,6 +194,8 @@ class TransactionQueue(object):
if not destinations: if not destinations:
return return
sent_pdus_destination_dist.inc_by(len(destinations))
for destination in destinations: for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append( self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, order) (pdu, order)
@ -139,7 +205,10 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def enqueue_presence(self, destination, states): def send_presence(self, destination, states):
if not self.can_send_to(destination):
return
self.pending_presence_by_dest.setdefault(destination, {}).update({ self.pending_presence_by_dest.setdefault(destination, {}).update({
state.user_id: state for state in states state.user_id: state for state in states
}) })
@ -148,12 +217,19 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def enqueue_edu(self, edu, key=None): def send_edu(self, destination, edu_type, content, key=None):
destination = edu.destination edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
sent_edus_counter.inc()
if key: if key:
self.pending_edus_keyed_by_dest.setdefault( self.pending_edus_keyed_by_dest.setdefault(
destination, {} destination, {}
@ -165,7 +241,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def enqueue_failure(self, failure, destination): def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
return return
@ -180,7 +256,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def enqueue_device_messages(self, destination): def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
return return
@ -191,6 +267,9 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def get_current_token(self):
return 0
@defer.inlineCallbacks @defer.inlineCallbacks
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
@ -383,6 +462,13 @@ class TransactionQueue(object):
code = e.code code = e.code
response = e.response response = e.response
if e.code == 429 or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
raise e
logger.info( logger.info(
"TX [%s] {%s} got %d response", "TX [%s] {%s} got %d response",
destination, txn_id, code destination, txn_id, code

View File

@ -249,10 +249,15 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_public_rooms(self, remote_server, limit, since_token, def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None):
path = PREFIX + "/publicRooms" path = PREFIX + "/publicRooms"
args = {} args = {
"include_all_networks": "true" if include_all_networks else "false",
}
if third_party_instance_id:
args["third_party_instance_id"] = third_party_instance_id,
if limit: if limit:
args["limit"] = [str(limit)] args["limit"] = [str(limit)]
if since_token: if since_token:

View File

@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
parse_boolean_from_args,
) )
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse.types import ThirdPartyInstanceID
import functools import functools
import logging import logging
@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet):
def on_GET(self, origin, content, query): def on_GET(self, origin, content, query):
limit = parse_integer_from_args(query, "limit", 0) limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None) since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
query, "include_all_networks", False
)
third_party_instance_id = parse_string_from_args(
query, "third_party_instance_id", None
)
if include_all_networks:
network_tuple = None
elif third_party_instance_id:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
else:
network_tuple = ThirdPartyInstanceID(None, None)
data = yield self.room_list_handler.get_local_public_room_list( data = yield self.room_list_handler.get_local_public_room_list(
limit, since_token limit, since_token,
network_tuple=network_tuple
) )
defer.returnValue((200, data)) defer.returnValue((200, data))

View File

@ -24,7 +24,6 @@ from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .admin import AdminHandler from .admin import AdminHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler
from .search import SearchHandler from .search import SearchHandler
@ -56,7 +55,6 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs) self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs) self.room_context_handler = RoomContextHandler(hs)

View File

@ -61,6 +61,8 @@ class AuthHandler(BaseHandler):
for module, config in hs.config.password_providers for module, config in hs.config.password_providers
] ]
logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@ -160,7 +162,15 @@ class AuthHandler(BaseHandler):
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds) # it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
# the keys (confusingly, clientdict may contain a password
# param, creds is just what the user authed as for UI auth
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys()
)
defer.returnValue((True, creds, clientdict, session['id'])) defer.returnValue((True, creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
@ -378,12 +388,10 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password) return self._check_password(user_id, password)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None, def get_access_token_for_user_id(self, user_id, device_id=None,
initial_display_name=None): initial_display_name=None):
""" """
Gets login tuple for the user with the given user ID. Creates a new access token for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case. machanism (e.g. CAS), and the user_id converted to the canonical case.
@ -398,16 +406,13 @@ class AuthHandler(BaseHandler):
initial_display_name (str): display name to associate with the initial_display_name (str): display name to associate with the
device if it needs re-registering device if it needs re-registering
Returns: Returns:
A tuple of:
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -418,7 +423,7 @@ class AuthHandler(BaseHandler):
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
defer.returnValue((access_token, refresh_token)) defer.returnValue(access_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_exists(self, user_id): def check_user_exists(self, user_id):
@ -529,35 +534,19 @@ class AuthHandler(BaseHandler):
device_id) device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks def generate_access_token(self, user_id, extra_caveats=None):
def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None,
duration_in_ms=(60 * 60 * 1000)):
extra_caveats = extra_caveats or [] extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec() # Include a nonce, to make sure that each login gets a different
expiry = now + duration_in_ms # access token.
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
for caveat in extra_caveats: for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat) macaroon.add_first_party_caveat(caveat)
return macaroon.serialize() return macaroon.serialize()
def generate_refresh_token(self, user_id):
m = self._generate_base_macaroon(user_id)
m.add_first_party_caveat("type = refresh")
# Important to add a nonce, because otherwise every refresh token for a
# user will be the same.
m.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
return m.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login") macaroon.add_first_party_caveat("type = login")

View File

@ -34,9 +34,9 @@ class DeviceMessageHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_sender()
self.federation.register_edu_handler( hs.get_replication_layer().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu "m.direct_to_device", self.on_direct_to_device_edu
) )

View File

@ -339,3 +339,22 @@ class DirectoryHandler(BaseHandler):
yield self.auth.check_can_change_room_list(room_id, requester.user) yield self.auth.check_can_change_room_list(room_id, requester.user)
yield self.store.set_room_is_public(room_id, visibility == "public") yield self.store.set_room_is_public(room_id, visibility == "public")
@defer.inlineCallbacks
def edit_published_appservice_room_list(self, appservice_id, network_id,
room_id, visibility):
"""Add or remove a room from the appservice/network specific public
room list.
Args:
appservice_id (str): ID of the appservice that owns the list
network_id (str): The ID of the network the list is associated with
room_id (str)
visibility (str): either "public" or "private"
"""
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
yield self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)

View File

@ -111,6 +111,11 @@ class E2eKeysHandler(object):
failures[destination] = { failures[destination] = {
"status": 503, "message": "Not ready for retry", "status": 503, "message": "Not ready for retry",
} }
except Exception as e:
# include ConnectionRefused and other errors
failures[destination] = {
"status": 503, "message": e.message
}
yield preserve_context_over_deferred(defer.gatherResults([ yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination) preserve_fn(do_remote_query)(destination)
@ -222,6 +227,11 @@ class E2eKeysHandler(object):
failures[destination] = { failures[destination] = {
"status": 503, "message": "Not ready for retry", "status": 503, "message": "Not ready for retry",
} }
except Exception as e:
# include ConnectionRefused and other errors
failures[destination] = {
"status": 503, "message": e.message
}
yield preserve_context_over_deferred(defer.gatherResults([ yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination) preserve_fn(claim_client_keys)(destination)

View File

@ -80,22 +80,6 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
self.room_queues = {} self.room_queues = {}
def handle_new_event(self, event, destinations):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
remote home servers that may be interested.
Args:
event: The event to send
destinations: A list of destinations to send it to
Returns:
Deferred: Resolved when it has successfully been queued for
processing.
"""
return self.replication_layer.send_pdu(event, destinations)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
@ -268,9 +252,12 @@ class FederationHandler(BaseHandler):
except: except:
return False return False
# Parses mapping `event_id -> (type, state_key) -> state event_id`
# to get all state ids that we're interested in.
event_map = yield self.store.get_events([ event_map = yield self.store.get_events([
e_id for key_to_eid in event_to_state_ids.values() e_id
for key, e_id in key_to_eid for key_to_eid in event_to_state_ids.values()
for key, e_id in key_to_eid.items()
if key[0] != EventTypes.Member or check_match(key[1]) if key[0] != EventTypes.Member or check_match(key[1])
]) ])
@ -830,25 +817,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
new_pdu = event
users_in_room = yield self.store.get_joined_users_from_context(event, context)
destinations = set(
get_domain_from_id(user_id) for user_id in users_in_room
if not self.hs.is_mine_id(user_id)
)
destinations.discard(origin)
logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
self.replication_layer.send_pdu(new_pdu, destinations)
state_ids = context.prev_state_ids.values() state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set( auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids [event.event_id] + state_ids
@ -1055,24 +1023,6 @@ class FederationHandler(BaseHandler):
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
new_pdu = event
users_in_room = yield self.store.get_joined_users_from_context(event, context)
destinations = set(
get_domain_from_id(user_id) for user_id in users_in_room
if not self.hs.is_mine_id(user_id)
)
destinations.discard(origin)
logger.debug(
"on_send_leave_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
self.replication_layer.send_pdu(new_pdu, destinations)
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -372,11 +372,12 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_receipts(): def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler receipts = yield self.store.get_linearized_receipts_for_room(
receipts = yield receipts_handler.get_receipts_for_room(
room_id, room_id,
now_token.receipt_key to_key=now_token.receipt_key,
) )
if not receipts:
receipts = []
defer.returnValue(receipts) defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults( presence, receipts, (messages, token) = yield defer.gatherResults(

View File

@ -22,9 +22,9 @@ from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.types import ( from synapse.types import (
UserID, RoomAlias, RoomStreamToken, get_domain_from_id UserID, RoomAlias, RoomStreamToken,
) )
from synapse.util.async import run_on_reactor, ReadWriteLock from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -50,6 +50,10 @@ class MessageHandler(BaseHandler):
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
@ -191,36 +195,38 @@ class MessageHandler(BaseHandler):
""" """
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder) with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key) target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}: if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
profile = self.hs.get_handlers().profile_handler profile = self.hs.get_handlers().profile_handler
content = builder.content content = builder.content
try: try:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e: except Exception as e:
logger.info( logger.info(
"Failed to get profile information for %r: %s", "Failed to get profile information for %r: %s",
target, e target, e
) )
if token_id is not None: if token_id is not None:
builder.internal_metadata.token_id = token_id builder.internal_metadata.token_id = token_id
if txn_id is not None: if txn_id is not None:
builder.internal_metadata.txn_id = txn_id builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event(
builder=builder,
prev_event_ids=prev_event_ids,
)
event, context = yield self._create_new_client_event(
builder=builder,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -599,13 +605,6 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id event_stream_id, max_stream_id
) )
users_in_room = yield self.store.get_joined_users_from_context(event, context)
destinations = [
get_domain_from_id(user_id) for user_id in users_in_room
if not self.hs.is_mine_id(user_id)
]
@defer.inlineCallbacks @defer.inlineCallbacks
def _notify(): def _notify():
yield run_on_reactor() yield run_on_reactor()
@ -618,7 +617,3 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)
preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations,
)

View File

@ -91,28 +91,29 @@ class PresenceHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.wheel_timer = WheelTimer() self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.federation = hs.get_replication_layer() self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.federation.register_edu_handler( self.replication.register_edu_handler(
"m.presence", self.incoming_presence "m.presence", self.incoming_presence
) )
self.federation.register_edu_handler( self.replication.register_edu_handler(
"m.presence_invite", "m.presence_invite",
lambda origin, content: self.invite_presence( lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.federation.register_edu_handler( self.replication.register_edu_handler(
"m.presence_accept", "m.presence_accept",
lambda origin, content: self.accept_presence( lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.federation.register_edu_handler( self.replication.register_edu_handler(
"m.presence_deny", "m.presence_deny",
lambda origin, content: self.deny_presence( lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),

View File

@ -33,8 +33,8 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_sender()
self.federation.register_edu_handler( hs.get_replication_layer().register_edu_handler(
"m.receipt", self._received_remote_receipt "m.receipt", self._received_remote_receipt
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
@ -100,7 +100,7 @@ class ReceiptsHandler(BaseHandler):
if not res: if not res:
# res will be None if this read receipt is 'old' # res will be None if this read receipt is 'old'
defer.returnValue(False) continue
stream_id, max_persisted_id = res stream_id, max_persisted_id = res
@ -109,6 +109,10 @@ class ReceiptsHandler(BaseHandler):
if max_batch_id is None or max_persisted_id > max_batch_id: if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id max_batch_id = max_persisted_id
if min_batch_id is None:
# no new receipts
defer.returnValue(False)
affected_room_ids = list(set([r["room_id"] for r in receipts])) affected_room_ids = list(set([r["room_id"] for r in receipts]))
with PreserveLoggingContext(): with PreserveLoggingContext():

View File

@ -81,7 +81,7 @@ class RegistrationHandler(BaseHandler):
"User ID already taken.", "User ID already taken.",
errcode=Codes.USER_IN_USE, errcode=Codes.USER_IN_USE,
) )
user_data = yield self.auth.get_user_from_macaroon(guest_access_token) user_data = yield self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart: if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError( raise AuthError(
403, 403,
@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, duration_in_ms, def get_or_create_user(self, requester, localpart, displayname,
password_hash=None): password_hash=None):
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self.auth_handler().generate_access_token( token = self.auth_handler().generate_access_token(user_id)
user_id, None, duration_in_ms)
if need_register: if need_register:
yield self.store.register( yield self.store.register(

View File

@ -44,16 +44,19 @@ class RoomCreationHandler(BaseHandler):
"join_rules": JoinRules.INVITE, "join_rules": JoinRules.INVITE,
"history_visibility": "shared", "history_visibility": "shared",
"original_invitees_have_ops": False, "original_invitees_have_ops": False,
"guest_can_join": True,
}, },
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE, "join_rules": JoinRules.INVITE,
"history_visibility": "shared", "history_visibility": "shared",
"original_invitees_have_ops": True, "original_invitees_have_ops": True,
"guest_can_join": True,
}, },
RoomCreationPreset.PUBLIC_CHAT: { RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC, "join_rules": JoinRules.PUBLIC,
"history_visibility": "shared", "history_visibility": "shared",
"original_invitees_have_ops": False, "original_invitees_have_ops": False,
"guest_can_join": False,
}, },
} }
@ -336,6 +339,13 @@ class RoomCreationHandler(BaseHandler):
content={"history_visibility": config["history_visibility"]} content={"history_visibility": config["history_visibility"]}
) )
if config["guest_can_join"]:
if (EventTypes.GuestAccess, '') not in initial_state:
yield send(
etype=EventTypes.GuestAccess,
content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items(): for (etype, state_key), content in initial_state.items():
yield send( yield send(
etype=etype, etype=etype,

View File

@ -22,6 +22,7 @@ from synapse.api.constants import (
) )
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
from collections import namedtuple from collections import namedtuple
from unpaddedbase64 import encode_base64, decode_base64 from unpaddedbase64 import encode_base64, decode_base64
@ -34,6 +35,10 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomListHandler, self).__init__(hs) super(RoomListHandler, self).__init__(hs)
@ -41,22 +46,43 @@ class RoomListHandler(BaseHandler):
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None, def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None): search_filter=None,
if search_filter: network_tuple=EMTPY_THIRD_PARTY_ID,):
# We explicitly don't bother caching searches. """Generate a local public room list.
return self._get_public_room_list(limit, since_token, search_filter)
There are multiple different lists: the main one plus one per third
party network. A client can ask for a specific list or to return all.
Args:
limit (int)
since_token (str)
search_filter (dict)
network_tuple (ThirdPartyInstanceID): Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
"""
if search_filter or (network_tuple and network_tuple.appservice_id is not None):
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
result = self.response_cache.get((limit, since_token)) result = self.response_cache.get((limit, since_token))
if not result: if not result:
result = self.response_cache.set( result = self.response_cache.set(
(limit, since_token), (limit, since_token),
self._get_public_room_list(limit, since_token) self._get_public_room_list(
limit, since_token, network_tuple=network_tuple
)
) )
return result return result
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None, def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None): search_filter=None,
network_tuple=EMTPY_THIRD_PARTY_ID,):
if since_token and since_token != "END": if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token) since_token = RoomListNextBatch.from_token(since_token)
else: else:
@ -73,14 +99,15 @@ class RoomListHandler(BaseHandler):
current_public_id = yield self.store.get_current_public_room_stream_id() current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes( newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id public_room_stream_id, current_public_id,
network_tuple=network_tuple,
) )
else: else:
stream_token = yield self.store.get_room_max_stream_ordering() stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id() public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id( room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id public_room_stream_id, network_tuple=network_tuple,
) )
# We want to return rooms in a particular order: the number of joined # We want to return rooms in a particular order: the number of joined
@ -311,7 +338,8 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None, def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
if search_filter: if search_filter:
# We currently don't support searching across federation, so we have # We currently don't support searching across federation, so we have
# to do it manually without pagination # to do it manually without pagination
@ -320,6 +348,8 @@ class RoomListHandler(BaseHandler):
res = yield self._get_remote_list_cached( res = yield self._get_remote_list_cached(
server_name, limit=limit, since_token=since_token, server_name, limit=limit, since_token=since_token,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
if search_filter: if search_filter:
@ -332,22 +362,30 @@ class RoomListHandler(BaseHandler):
defer.returnValue(res) defer.returnValue(res)
def _get_remote_list_cached(self, server_name, limit=None, since_token=None, def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
repl_layer = self.hs.get_replication_layer() repl_layer = self.hs.get_replication_layer()
if search_filter: if search_filter:
# We can't cache when asking for search # We can't cache when asking for search
return repl_layer.get_public_rooms( return repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token, server_name, limit=limit, since_token=since_token,
search_filter=search_filter, search_filter=search_filter, include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
result = self.remote_response_cache.get((server_name, limit, since_token)) key = (
server_name, limit, since_token, include_all_networks,
third_party_instance_id,
)
result = self.remote_response_cache.get(key)
if not result: if not result:
result = self.remote_response_cache.set( result = self.remote_response_cache.set(
(server_name, limit, since_token), key,
repl_layer.get_public_rooms( repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token, server_name, limit=limit, since_token=since_token,
search_filter=search_filter, search_filter=search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
) )
return result return result

View File

@ -277,6 +277,7 @@ class SyncHandler(object):
""" """
with Measure(self.clock, "load_filtered_recents"): with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
if recents is None or newly_joined_room or timeline_limit < len(recents): if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True limited = True
@ -293,7 +294,7 @@ class SyncHandler(object):
else: else:
recents = [] recents = []
if not limited: if not limited or block_all_timeline:
defer.returnValue(TimelineBatch( defer.returnValue(TimelineBatch(
events=recents, events=recents,
prev_batch=now_token, prev_batch=now_token,
@ -509,6 +510,7 @@ class SyncHandler(object):
Returns: Returns:
Deferred(SyncResult) Deferred(SyncResult)
""" """
logger.info("Calculating sync response for %r", sync_config.user)
# NB: The now_token gets changed by some of the generate_sync_* methods, # NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability # this is due to some of the underlying streams not supporting the ability
@ -531,9 +533,14 @@ class SyncHandler(object):
) )
newly_joined_rooms, newly_joined_users = res newly_joined_rooms, newly_joined_users = res
yield self._generate_sync_entry_for_presence( block_all_presence_data = (
sync_result_builder, newly_joined_rooms, newly_joined_users since_token is None and
sync_config.filter_collection.blocks_all_presence()
) )
if not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_users
)
yield self._generate_sync_entry_for_to_device(sync_result_builder) yield self._generate_sync_entry_for_to_device(sync_result_builder)
@ -569,16 +576,20 @@ class SyncHandler(object):
# We only delete messages when a new message comes in, but that's # We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point. # fine so long as we delete them at some point.
logger.debug("Deleting messages up to %d", since_stream_id) deleted = yield self.store.delete_messages_for_device(
yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id user_id, device_id, since_stream_id
) )
logger.info("Deleted %d to-device messages up to %d",
deleted, since_stream_id)
logger.debug("Getting messages up to %d", now_token.to_device_key)
messages, stream_id = yield self.store.get_new_messages_for_device( messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key user_id, device_id, since_stream_id, now_token.to_device_key
) )
logger.debug("Got messages up to %d: %r", stream_id, messages)
logger.info(
"Returning %d to-device messages between %d and %d (current token: %d)",
len(messages), since_stream_id, stream_id, now_token.to_device_key
)
sync_result_builder.now_token = now_token.copy_and_replace( sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id "to_device_key", stream_id
) )
@ -709,13 +720,20 @@ class SyncHandler(object):
`(newly_joined_rooms, newly_joined_users)` `(newly_joined_rooms, newly_joined_users)`
""" """
user_id = sync_result_builder.sync_config.user.to_string() user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
now_token, ephemeral_by_room = yield self.ephemeral_by_room( sync_result_builder.since_token is None and
sync_result_builder.sync_config, sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token,
) )
sync_result_builder.now_token = now_token
if block_all_room_ephemeral:
ephemeral_by_room = {}
else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_result_builder.sync_config,
now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token,
)
sync_result_builder.now_token = now_token
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id, "m.ignored_user_list", user_id=user_id,

View File

@ -55,9 +55,9 @@ class TypingHandler(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.wheel_timer = WheelTimer(bucket_size=5000) self.wheel_timer = WheelTimer(bucket_size=5000)
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_sender()
self.federation.register_edu_handler("m.typing", self._recv_edu) hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)

View File

@ -33,6 +33,7 @@ from synapse.api.errors import (
from signedjson.sign import sign_json from signedjson.sign import sign_json
import cgi
import simplejson as json import simplejson as json
import logging import logging
import random import random
@ -292,12 +293,7 @@ class MatrixFederationHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type") check_content_type_is_json(response.headers)
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
body = yield preserve_context_over_fn(readBody, response) body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -342,12 +338,7 @@ class MatrixFederationHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type") check_content_type_is_json(response.headers)
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
body = yield preserve_context_over_fn(readBody, response) body = yield preserve_context_over_fn(readBody, response)
@ -400,12 +391,7 @@ class MatrixFederationHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type") check_content_type_is_json(response.headers)
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
body = yield preserve_context_over_fn(readBody, response) body = yield preserve_context_over_fn(readBody, response)
@ -525,3 +511,29 @@ def _flatten_response_never_received(e):
) )
else: else:
return "%s: %s" % (type(e).__name__, e.message,) return "%s: %s" % (type(e).__name__, e.message,)
def check_content_type_is_json(headers):
"""
Check that a set of HTTP headers have a Content-Type header, and that it
is application/json.
Args:
headers (twisted.web.http_headers.Headers): headers to check
Raises:
RuntimeError if the
"""
c_type = headers.getRawHeaders("Content-Type")
if c_type is None:
raise RuntimeError(
"No Content-Type header"
)
c_type = c_type[0] # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
)

View File

@ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False):
parameter is present and not one of "true" or "false". parameter is present and not one of "true" or "false".
""" """
if name in request.args: return parse_boolean_from_args(request.args, name, default, required)
def parse_boolean_from_args(args, name, default=None, required=False):
if name in args:
try: try:
return { return {
"true": True, "true": True,
"false": False, "false": False,
}[request.args[name][0]] }[args[name][0]]
except: except:
message = ( message = (
"Boolean query parameter %r must be one of" "Boolean query parameter %r must be one of"

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.util import DeferredTimedOutError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
@ -143,6 +144,12 @@ class Notifier(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
else:
self.federation_sender = None
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.clock.looping_call( self.clock.looping_call(
@ -220,6 +227,9 @@ class Notifier(object):
# poke any interested application service. # poke any interested application service.
self.appservice_handler.notify_interested_services(room_stream_id) self.appservice_handler.notify_interested_services(room_stream_id)
if self.federation_sender:
self.federation_sender.notify_new_events(room_stream_id)
if event.type == EventTypes.Member and event.membership == Membership.JOIN: if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id) self._user_joined_room(event.state_key, event.room_id)
@ -285,14 +295,7 @@ class Notifier(object):
result = None result = None
if timeout: if timeout:
# Will be set to a _NotificationListener that we'll be waiting on. end_time = self.clock.time_msec() + timeout
# Allows us to cancel it.
listener = None
def timed_out():
if listener:
listener.deferred.cancel()
timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token prev_token = from_token
while not result: while not result:
@ -303,6 +306,10 @@ class Notifier(object):
if result: if result:
break break
now = self.clock.time_msec()
if end_time <= now:
break
# Now we wait for the _NotifierUserStream to be told there # Now we wait for the _NotifierUserStream to be told there
# is a new token. # is a new token.
# We need to supply the token we supplied to callback so # We need to supply the token we supplied to callback so
@ -310,11 +317,14 @@ class Notifier(object):
prev_token = current_token prev_token = current_token
listener = user_stream.new_listener(prev_token) listener = user_stream.new_listener(prev_token)
with PreserveLoggingContext(): with PreserveLoggingContext():
yield listener.deferred yield self.clock.time_bound_deferred(
listener.deferred,
time_out=(end_time - now) / 1000.
)
except DeferredTimedOutError:
break
except defer.CancelledError: except defer.CancelledError:
break break
self.clock.cancel_call_later(timer, ignore_errs=True)
else: else:
current_token = user_stream.current_token current_token = user_stream.current_token
result = yield callback(from_token, current_token) result = yield callback(from_token, current_token)
@ -483,22 +493,27 @@ class Notifier(object):
""" """
listener = _NotificationListener(None) listener = _NotificationListener(None)
def timed_out(): end_time = self.clock.time_msec() + timeout
listener.deferred.cancel()
timer = self.clock.call_later(timeout / 1000., timed_out)
while True: while True:
listener.deferred = self.replication_deferred.observe() listener.deferred = self.replication_deferred.observe()
result = yield callback() result = yield callback()
if result: if result:
break break
now = self.clock.time_msec()
if end_time <= now:
break
try: try:
with PreserveLoggingContext(): with PreserveLoggingContext():
yield listener.deferred yield self.clock.time_bound_deferred(
listener.deferred,
time_out=(end_time - now) / 1000.
)
except DeferredTimedOutError:
break
except defer.CancelledError: except defer.CancelledError:
break break
self.clock.cancel_call_later(timer, ignore_errs=True)
defer.returnValue(result) defer.returnValue(result)

View File

@ -87,12 +87,12 @@ class BulkPushRuleEvaluator:
condition_cache = {} condition_cache = {}
for uid, rules in self.rules_by_user.items(): for uid, rules in self.rules_by_user.items():
display_name = None display_name = room_members.get(uid, {}).get("display_name", None)
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid)) if not display_name:
if member_ev_id: # Handle the case where we are pushing a membership event to
member_ev = yield self.store.get_event(member_ev_id, allow_none=True) # that user, as they might not be already joined.
if member_ev: if event.type == EventTypes.Member and event.state_key == uid:
display_name = member_ev.content.get("displayname", None) display_name = event.content.get("displayname", None)
filtered = filtered_by_user[uid] filtered = filtered_by_user[uid]
if len(filtered) == 0: if len(filtered) == 0:

View File

@ -49,8 +49,8 @@ CONDITIONAL_REQUIREMENTS = {
"Jinja2>=2.8": ["Jinja2>=2.8"], "Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"], "bleach>=1.4.2": ["bleach>=1.4.2"],
}, },
"ldap": { "matrix-synapse-ldap3": {
"ldap3>=1.0": ["ldap3>=1.0"], "matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"],
}, },
"psutil": { "psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"], "psutil>=2.0.0": ["psutil>=2.0.0"],
@ -69,6 +69,7 @@ def requirements(config=None, include_conditional=False):
def github_link(project, version, egg): def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg) return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
DEPENDENCY_LINKS = { DEPENDENCY_LINKS = {
} }
@ -156,6 +157,7 @@ def list_requirements():
result.append(requirement) result.append(requirement)
return result return result
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
sys.stdout.writelines(req + "\n" for req in list_requirements()) sys.stdout.writelines(req + "\n" for req in list_requirements())

View File

@ -0,0 +1,60 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
class ExpireCacheResource(Resource):
"""
HTTP endpoint for expiring storage caches.
POST /_synapse/replication/expire_cache HTTP/1.1
Content-Type: application/json
{
"invalidate": [
{
"name": "func_name",
"keys": ["key1", "key2"]
}
]
}
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
for row in content["invalidate"]:
name = row["name"]
keys = tuple(row["keys"])
getattr(self.store, name).invalidate(keys)
respond_with_json_bytes(request, 200, "{}")

View File

@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource from synapse.replication.presence_resource import PresenceResource
from synapse.replication.expire_cache import ExpireCacheResource
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -44,6 +45,7 @@ STREAM_NAMES = (
("caches",), ("caches",),
("to_device",), ("to_device",),
("public_rooms",), ("public_rooms",),
("federation",),
) )
@ -116,11 +118,14 @@ class ReplicationResource(Resource):
self.sources = hs.get_event_sources() self.sources = hs.get_event_sources()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler() self.typing_handler = hs.get_typing_handler()
self.federation_sender = hs.get_federation_sender()
self.notifier = hs.notifier self.notifier = hs.notifier
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.config = hs.get_config()
self.putChild("remove_pushers", PusherResource(hs)) self.putChild("remove_pushers", PusherResource(hs))
self.putChild("syncing_users", PresenceResource(hs)) self.putChild("syncing_users", PresenceResource(hs))
self.putChild("expire_cache", ExpireCacheResource(hs))
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
@ -134,6 +139,7 @@ class ReplicationResource(Resource):
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id() public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -148,6 +154,7 @@ class ReplicationResource(Resource):
caches_token, caches_token,
int(stream_token.to_device_key), int(stream_token.to_device_key),
int(public_rooms_token), int(public_rooms_token),
int(federation_token),
)) ))
@request_handler() @request_handler()
@ -164,8 +171,13 @@ class ReplicationResource(Resource):
} }
request_streams["streams"] = parse_string(request, "streams") request_streams["streams"] = parse_string(request, "streams")
federation_ack = parse_integer(request, "federation_ack", None)
def replicate(): def replicate():
return self.replicate(request_streams, limit) return self.replicate(
request_streams, limit,
federation_ack=federation_ack
)
writer = yield self.notifier.wait_for_replication(replicate, timeout) writer = yield self.notifier.wait_for_replication(replicate, timeout)
result = writer.finish() result = writer.finish()
@ -183,7 +195,7 @@ class ReplicationResource(Resource):
finish_request(request) finish_request(request)
@defer.inlineCallbacks @defer.inlineCallbacks
def replicate(self, request_streams, limit): def replicate(self, request_streams, limit, federation_ack=None):
writer = _Writer() writer = _Writer()
current_token = yield self.current_replication_token() current_token = yield self.current_replication_token()
logger.debug("Replicating up to %r", current_token) logger.debug("Replicating up to %r", current_token)
@ -202,6 +214,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.debug("Replicated %d rows", writer.total) logger.debug("Replicated %d rows", writer.total)
@ -462,7 +475,24 @@ class ReplicationResource(Resource):
) )
upto_token = _position_from_rows(public_rooms_rows, current_position) upto_token = _position_from_rows(public_rooms_rows, current_position)
writer.write_header_and_rows("public_rooms", public_rooms_rows, ( writer.write_header_and_rows("public_rooms", public_rooms_rows, (
"position", "room_id", "visibility" "position", "room_id", "visibility", "appservice_id", "network_id",
), position=upto_token)
def federation(self, writer, current_token, limit, request_streams, federation_ack):
if self.config.send_federation:
return
current_position = current_token.federation
federation = request_streams.get("federation")
if federation is not None and federation != current_position:
federation_rows = self.federation_sender.get_replication_rows(
federation, limit, federation_ack=federation_ack,
)
upto_token = _position_from_rows(federation_rows, current_position)
writer.write_header_and_rows("federation", federation_rows, (
"position", "type", "content",
), position=upto_token) ), position=upto_token)
@ -497,6 +527,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
"federation",
))): ))):
__slots__ = [] __slots__ = []

View File

@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore):
else: else:
self._cache_id_gen = None self._cache_id_gen = None
self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
self.http_client = hs.get_simple_http_client()
def stream_positions(self): def stream_positions(self):
pos = {} pos = {}
if self._cache_id_gen: if self._cache_id_gen:
@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore):
logger.info("Got unexpected cache_func: %r", cache_func) logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"])) self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None) return defer.succeed(None)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys)
@defer.inlineCallbacks
def _send_invalidation_poke(self, cache_func, keys):
try:
yield self.http_client.post_json_get_json(self.expire_cache_url, {
"invalidate": [{
"name": cache_func.__name__,
"keys": list(keys),
}]
})
except:
logger.exception("Failed to poke on expire_cache")

View File

@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
"DeviceInboxStreamChangeCache", "DeviceInboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token() self._device_inbox_id_gen.get_current_token()
) )
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token()
)
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
delete_messages_for_device = DataStore.delete_messages_for_device.__func__ delete_messages_for_device = DataStore.delete_messages_for_device.__func__
delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
def stream_positions(self): def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions() result = super(SlavedDeviceInboxStore, self).stream_positions()
@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen.advance(int(stream["position"])) self._device_inbox_id_gen.advance(int(stream["position"]))
for row in stream["rows"]: for row in stream["rows"]:
stream_id = row[0] stream_id = row[0]
user_id = row[1] entity = row[1]
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id if entity.startswith("@"):
) self._device_inbox_stream_cache.entity_has_changed(
entity, stream_id
)
else:
self._device_federation_outbox_stream_cache.entity_has_changed(
entity, stream_id
)
return super(SlavedDeviceInboxStore, self).process_replication(result) return super(SlavedDeviceInboxStore, self).process_replication(result)

View File

@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json import ujson as json
import logging
logger = logging.getLogger(__name__)
# So, um, we want to borrow a load of functions intended for reading from # So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the # a DataStore, but we don't want to take functions that either write to the
@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore):
EventFederationStore.__dict__["_get_forward_extremeties_for_room"] EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
) )
get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
get_federation_out_pos = DataStore.get_federation_out_pos.__func__
update_federation_out_pos = DataStore.update_federation_out_pos.__func__
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token() result["events"] = self._stream_id_gen.get_current_token()
@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("events") stream = result.get("events")
if stream: if stream:
self._stream_id_gen.advance(int(stream["position"])) self._stream_id_gen.advance(int(stream["position"]))
if stream["rows"]:
logger.info("Got %d event rows", len(stream["rows"]))
for row in stream["rows"]: for row in stream["rows"]:
self._process_replication_row( self._process_replication_row(
row, backfilled=False, state_resets=state_resets row, backfilled=False, state_resets=state_resets

View File

@ -15,6 +15,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore):
DataStore.get_current_public_room_stream_id.__func__ DataStore.get_current_public_room_stream_id.__func__
) )
get_public_room_ids_at_stream_id = ( get_public_room_ids_at_stream_id = (
DataStore.get_public_room_ids_at_stream_id.__func__ RoomStore.__dict__["get_public_room_ids_at_stream_id"]
) )
get_public_room_ids_at_stream_id_txn = ( get_public_room_ids_at_stream_id_txn = (
DataStore.get_public_room_ids_at_stream_id_txn.__func__ DataStore.get_public_room_ids_at_stream_id_txn.__func__

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore from synapse.storage.transactions import TransactionStore
@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore
class TransactionStore(BaseSlavedStore): class TransactionStore(BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[ get_destination_retry_timings = TransactionStore.__dict__[
"get_destination_retry_timings" "get_destination_retry_timings"
].orig ]
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__ _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
_set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
# For now, don't record the destination rety timings prep_send_transaction = DataStore.prep_send_transaction.__func__
def set_destination_retry_timings(*args, **kwargs): delivered_txn = DataStore.delivered_txn.__func__
return defer.succeed(None)

View File

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server) ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server) ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
@ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
)
def __init__(self, hs):
super(ClientAppserviceDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
return self._edit(request, network_id, room_id, visibility)
def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private")
@defer.inlineCallbacks
def _edit(self, request, network_id, room_id, visibility):
requester = yield self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
403, "Only appservices can edit the appservice published room list"
)
yield self.handlers.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility,
)
defer.returnValue((200, {}))

View File

@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet):
password=login_submission["password"], password=login_submission["password"],
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id( user_id, device_id,
user_id, device_id, login_submission.get("initial_device_display_name"),
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet):
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id( user_id, device_id,
user_id, device_id, login_submission.get("initial_device_display_name"),
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device( device_id = yield self._register_device(
registered_user_id, login_submission registered_user_id, login_submission
) )
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id( registered_user_id, device_id,
registered_user_id, device_id, login_submission.get("initial_device_display_name"),
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": registered_user_id, "user_id": registered_user_id,
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
else: else:

View File

@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs) super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet):
if "displayname" not in user_json: if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.") raise SynapseError(400, "Expected 'displayname' key.")
if "duration_seconds" not in user_json:
raise SynapseError(400, "Expected 'duration_seconds' key.")
localpart = user_json["localpart"].encode("utf-8") localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8")
duration_seconds = 0
try:
duration_seconds = int(user_json["duration_seconds"])
except ValueError:
raise SynapseError(400, "Failed to parse 'duration_seconds'")
if duration_seconds > self.direct_user_creation_max_duration:
duration_seconds = self.direct_user_creation_max_duration
password_hash = user_json["password_hash"].encode("utf-8") \ password_hash = user_json["password_hash"].encode("utf-8") \
if user_json.get("password_hash") else None if user_json.get("password_hash") else None
@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
requester=requester, requester=requester,
localpart=localpart, localpart=localpart,
displayname=displayname, displayname=displayname,
duration_in_ms=(duration_seconds * 1000),
password_hash=password_hash password_hash=password_hash
) )

View File

@ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
from synapse.events.utils import serialize_event, format_event_for_client_v2 from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer parse_json_object_from_request, parse_string, parse_integer
@ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
since_token = content.get("since", None) since_token = content.get("since", None)
search_filter = content.get("filter", None) search_filter = content.get("filter", None)
include_all_networks = content.get("include_all_networks", False)
third_party_instance_id = content.get("third_party_instance_id", None)
if include_all_networks:
network_tuple = None
if third_party_instance_id is not None:
raise SynapseError(
400, "Can't use include_all_networks with an explicit network"
)
elif third_party_instance_id is None:
network_tuple = ThirdPartyInstanceID(None, None)
else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = yield handler.get_remote_public_room_list(
@ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,
search_filter=search_filter, search_filter=search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
else: else:
data = yield handler.get_local_public_room_list( data = yield handler.get_local_public_room_list(
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,
search_filter=search_filter, search_filter=search_filter,
network_tuple=network_tuple,
) )
defer.returnValue((200, data)) defer.returnValue((200, data))
@ -369,6 +386,24 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
})) }))
class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
self.state = hs.get_state_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
yield self.auth.get_user_by_req(request)
users_with_profile = yield self.state.get_current_user_in_room(room_id)
defer.returnValue((200, {
"joined": users_with_profile
}))
# TODO: Needs better unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
@ -692,6 +727,22 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
class JoinedRoomsRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/joined_rooms$")
def __init__(self, hs):
super(JoinedRoomsRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
def register_txn_path(servlet, regex_string, http_server, with_get=False): def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""Registers a transaction-based path. """Registers a transaction-based path.
@ -727,6 +778,7 @@ def register_servlets(hs, http_server):
RoomStateEventRestServlet(hs).register(http_server) RoomStateEventRestServlet(hs).register(http_server)
RoomCreateRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server)
RoomForgetRestServlet(hs).register(http_server) RoomForgetRestServlet(hs).register(http_server)
@ -738,4 +790,5 @@ def register_servlets(hs, http_server):
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
RoomEventContext(hs).register(http_server) RoomEventContext(hs).register(http_server)

View File

@ -39,7 +39,7 @@ class DevicesRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
devices = yield self.device_handler.get_devices_by_user( devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string() requester.user.to_string()
) )
@ -63,7 +63,7 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
device = yield self.device_handler.get_device( device = yield self.device_handler.get_device(
requester.user.to_string(), requester.user.to_string(),
device_id, device_id,
@ -99,7 +99,7 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, device_id): def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = servlet.parse_json_object_from_request(request) body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device( yield self.device_handler.update_device(

View File

@ -65,7 +65,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -94,10 +94,6 @@ class KeyUploadServlet(RestServlet):
class KeyQueryServlet(RestServlet): class KeyQueryServlet(RestServlet):
""" """
GET /keys/query/<user_id> HTTP/1.1
GET /keys/query/<user_id>/<device_id> HTTP/1.1
POST /keys/query HTTP/1.1 POST /keys/query HTTP/1.1
Content-Type: application/json Content-Type: application/json
{ {
@ -131,11 +127,7 @@ class KeyQueryServlet(RestServlet):
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_v2_patterns(
"/keys/query(?:" "/keys/query$",
"/(?P<user_id>[^/]*)(?:"
"/(?P<device_id>[^/]*)"
")?"
")?",
releases=() releases=()
) )
@ -149,31 +141,16 @@ class KeyQueryServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id): def on_POST(self, request):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout) result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue((200, result)) defer.returnValue((200, result))
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
requester = yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.e2e_keys_handler.query_devices(
{"device_keys": {user_id: device_ids}},
timeout,
)
defer.returnValue((200, result))
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):
""" """
GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
POST /keys/claim HTTP/1.1 POST /keys/claim HTTP/1.1
{ {
"one_time_keys": { "one_time_keys": {
@ -191,9 +168,7 @@ class OneTimeKeyServlet(RestServlet):
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_v2_patterns(
"/keys/claim(?:/?|(?:/" "/keys/claim$",
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
")?)",
releases=() releases=()
) )
@ -203,18 +178,8 @@ class OneTimeKeyServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm): def on_POST(self, request):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.e2e_keys_handler.claim_one_time_keys(
{"one_time_keys": {user_id: {device_id: algorithm}}},
timeout,
)
defer.returnValue((200, result))
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.claim_one_time_keys( result = yield self.e2e_keys_handler.claim_one_time_keys(

View File

@ -36,7 +36,7 @@ class ReceiptRestServlet(RestServlet):
super(ReceiptRestServlet, self).__init__() super(ReceiptRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import synapse
from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_object_from_request(request)
kind = "user" kind = "user"
if "kind" in request.args: if "kind" in request.args:
kind = request.args["kind"][0] kind = request.args["kind"][0]
if kind == "guest": if kind == "guest":
ret = yield self._do_guest_registration() ret = yield self._do_guest_registration(body)
defer.returnValue(ret) defer.returnValue(ret)
return return
elif kind != "user": elif kind != "user":
@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind,) "Do not understand membership kind: %s" % (kind,)
) )
body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these # we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us. # in sessions. Pull out the username/password provided to us.
desired_password = None desired_password = None
@ -373,8 +374,7 @@ class RegisterRestServlet(RestServlet):
def _create_registration_details(self, user_id, params): def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user """Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token Allocates device_id if one was not given; also creates access_token.
and refresh_token.
Args: Args:
(str) user_id: full canonical @user:id (str) user_id: full canonical @user:id
@ -385,8 +385,8 @@ class RegisterRestServlet(RestServlet):
""" """
device_id = yield self._register_device(user_id, params) device_id = yield self._register_device(user_id, params)
access_token, refresh_token = ( access_token = (
yield self.auth_handler.get_login_tuple_for_user_id( yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name") initial_display_name=params.get("initial_device_display_name")
) )
@ -396,7 +396,6 @@ class RegisterRestServlet(RestServlet):
"user_id": user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"refresh_token": refresh_token,
"device_id": device_id, "device_id": device_id,
}) })
@ -421,20 +420,28 @@ class RegisterRestServlet(RestServlet):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self): def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled")) defer.returnValue((403, "Guest access is disabled"))
user_id, _ = yield self.registration_handler.register( user_id, _ = yield self.registration_handler.register(
generate_token=False, generate_token=False,
make_guest=True make_guest=True
) )
# we don't allow guests to specify their own device_id, because
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
access_token = self.auth_handler.generate_access_token( access_token = self.auth_handler.generate_access_token(
user_id, ["guest = true"] user_id, ["guest = true"]
) )
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
# so long as we don't return a refresh_token here.
defer.returnValue((200, { defer.returnValue((200, {
"user_id": user_id, "user_id": user_id,
"device_id": device_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
})) }))

View File

@ -50,7 +50,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _put(self, request, message_type, txn_id): def _put(self, request, message_type, txn_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)

View File

@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
joined = self.encode_joined( joined = self.encode_joined(
sync_result.joined, time_now, requester.access_token_id sync_result.joined, time_now, requester.access_token_id, filter.event_fields
) )
invited = self.encode_invited( invited = self.encode_invited(
@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet):
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, time_now, requester.access_token_id sync_result.archived, time_now, requester.access_token_id, filter.event_fields
) )
response_content = { response_content = {
@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet):
formatted.append(event) formatted.append(event)
return {"events": formatted} return {"events": formatted}
def encode_joined(self, rooms, time_now, token_id): def encode_joined(self, rooms, time_now, token_id, event_fields):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet):
calculations calculations
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
Returns: Returns:
dict[str, dict[str, object]]: the joined rooms list, in our dict[str, dict[str, object]]: the joined rooms list, in our
response format response format
@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, time_now, token_id room, time_now, token_id, only_fields=event_fields
) )
return joined return joined
@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, time_now, token_id): def encode_archived(self, rooms, time_now, token_id, event_fields):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet):
calculations calculations
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
Returns: Returns:
dict[str, dict[str, object]]: The invited rooms list, in our dict[str, dict[str, object]]: The invited rooms list, in our
response format response format
@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, time_now, token_id, joined=False room, time_now, token_id, joined=False, only_fields=event_fields
) )
return joined return joined
@staticmethod @staticmethod
def encode_room(room, time_now, token_id, joined=True): def encode_room(room, time_now, token_id, joined=True, only_fields=None):
""" """
Args: Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet):
of transaction IDs of transaction IDs
joined (bool): True if the user is joined to this room - will mean joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include.
Returns: Returns:
dict[str, object]: the room, encoded in our response format dict[str, object]: the room, encoded in our response format
""" """
@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet):
return serialize_event( return serialize_event(
event, time_now, token_id=token_id, event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,
only_event_fields=only_fields,
) )
state_dict = room.state state_dict = room.state

View File

@ -15,8 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) raise AuthError(403, "tokenrefresh is no longer supported.")
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_auth_handler()
refresh_result = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token
)
(user_id, new_refresh_token, device_id) = refresh_result
new_access_token = yield auth_handler.issue_access_token(
user_id, device_id
)
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}))
except KeyError:
raise SynapseError(400, "Missing required key 'refresh_token'.")
except StoreError:
raise AuthError(403, "Did not recognize refresh token")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View File

@ -381,7 +381,10 @@ def _calc_og(tree, media_uri):
if 'og:title' not in og: if 'og:title' not in og:
# do some basic spidering of the HTML # do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None if title and title[0].text is not None:
og['og:title'] = title[0].text.strip()
else:
og['og:title'] = None
if 'og:image' not in og: if 'og:image' not in og:
# TODO: extract a favicon failing all else # TODO: extract a favicon failing all else
@ -543,5 +546,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# We always add an ellipsis because at the very least # We always add an ellipsis because at the very least
# we chopped mid paragraph. # we chopped mid paragraph.
description = new_desc.strip() + "" description = new_desc.strip() + u""
return description if description else None return description if description else None

View File

@ -32,6 +32,9 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
@ -44,6 +47,7 @@ from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -124,6 +128,9 @@ class HomeServer(object):
'http_client_context_factory', 'http_client_context_factory',
'simple_http_client', 'simple_http_client',
'media_repository', 'media_repository',
'federation_transport_client',
'federation_sender',
'receipts_handler',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -265,9 +272,30 @@ class HomeServer(object):
def build_media_repository(self): def build_media_repository(self):
return MediaRepository(self) return MediaRepository(self)
def build_federation_transport_client(self):
return TransportLayerClient(self)
def build_federation_sender(self):
if self.should_send_federation():
return TransactionQueue(self)
elif not self.config.worker_app:
return FederationRemoteSendQueue(self)
else:
raise Exception("Workers cannot send federation traffic")
def build_receipts_handler(self):
return ReceiptsHandler(self)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
def should_send_federation(self):
"Should this server be sending federation traffic directly?"
return self.config.send_federation and (
not self.config.worker_app
or self.config.worker_app == "synapse.app.federation_sender"
)
def _make_dependency_method(depname): def _make_dependency_method(depname):
def _get(hs): def _get(hs):

View File

@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
@ -223,6 +222,7 @@ class DataStore(RoomMemberStore, RoomStore,
) )
self._stream_order_on_start = self.get_room_max_stream_ordering() self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)

View File

@ -561,12 +561,17 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol): def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
where = ""
sql = ( sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s" "SELECT %(retcol)s FROM %(table)s %(where)s"
) % { ) % {
"retcol": retcol, "retcol": retcol,
"table": table, "table": table,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), "where": where,
} }
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
@ -744,10 +749,15 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues): def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % ( if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
where = ""
update_sql = "UPDATE %s SET %s %s" % (
table, table,
", ".join("%s = ?" % (k,) for k in updatevalues), ", ".join("%s = ?" % (k,) for k in updatevalues),
" AND ".join("%s = ?" % (k,) for k in keyvalues) where,
) )
txn.execute( txn.execute(

View File

@ -39,6 +39,14 @@ class ApplicationServiceStore(SQLBaseStore):
def get_app_services(self): def get_app_services(self):
return self.services_cache return self.services_cache
def get_if_app_services_interested_in_user(self, user_id):
"""Check if the user is one associated with an app service
"""
for service in self.services_cache:
if service.is_interested_in_user(user_id):
return True
return False
def get_app_service_by_user_id(self, user_id): def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID. """Retrieve an application service from their user ID.

View File

@ -242,7 +242,7 @@ class DeviceInboxStore(SQLBaseStore):
device_id(str): The recipient device_id. device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to. up_to_stream_id(int): Where to delete messages up to.
Returns: Returns:
A deferred that resolves when the messages have been deleted. A deferred that resolves to the number of messages deleted.
""" """
def delete_messages_for_device_txn(txn): def delete_messages_for_device_txn(txn):
sql = ( sql = (
@ -251,6 +251,7 @@ class DeviceInboxStore(SQLBaseStore):
" AND stream_id <= ?" " AND stream_id <= ?"
) )
txn.execute(sql, (user_id, device_id, up_to_stream_id)) txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
return self.runInteraction( return self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn "delete_messages_for_device", delete_messages_for_device_txn
@ -269,27 +270,29 @@ class DeviceInboxStore(SQLBaseStore):
return defer.succeed([]) return defer.succeed([])
def get_all_new_device_messages_txn(txn): def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = ( sql = (
"SELECT stream_id FROM device_inbox" "SELECT stream_id, user_id"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY stream_id"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (last_pos, current_pos, limit))
stream_ids = txn.fetchall()
if not stream_ids:
return []
max_stream_id_in_limit = stream_ids[-1]
sql = (
"SELECT stream_id, user_id, device_id, message_json"
" FROM device_inbox" " FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?" " WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC" " ORDER BY stream_id ASC"
) )
txn.execute(sql, (last_pos, max_stream_id_in_limit)) txn.execute(sql, (last_pos, upper_pos))
return txn.fetchall() rows = txn.fetchall()
sql = (
"SELECT stream_id, destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn.fetchall())
return rows
return self.runInteraction( return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn "get_all_new_device_messages", get_all_new_device_messages_txn

View File

@ -39,6 +39,14 @@ class EventPushActionsStore(SQLBaseStore):
columns=["user_id", "stream_ordering"], columns=["user_id", "stream_ordering"],
) )
self.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
where_clause="highlight=1"
)
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
""" """
Args: Args:
@ -88,8 +96,11 @@ class EventPushActionsStore(SQLBaseStore):
topological_ordering, stream_ordering topological_ordering, stream_ordering
) )
# First get number of notifications.
# We don't need to put a notif=1 clause as all rows always have
# notif=1
sql = ( sql = (
"SELECT sum(notif), sum(highlight)" "SELECT count(*)"
" FROM event_push_actions ea" " FROM event_push_actions ea"
" WHERE" " WHERE"
" user_id = ?" " user_id = ?"
@ -99,13 +110,27 @@ class EventPushActionsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
row = txn.fetchone() row = txn.fetchone()
if row: notify_count = row[0] if row else 0
return {
"notify_count": row[0] or 0, # Now get the number of highlights
"highlight_count": row[1] or 0, sql = (
} "SELECT count(*)"
else: " FROM event_push_actions ea"
return {"notify_count": 0, "highlight_count": 0} " WHERE"
" highlight = 1"
" AND user_id = ?"
" AND room_id = ?"
" AND %s"
) % (lower_bound(token, self.database_engine, inclusive=False),)
txn.execute(sql, (user_id, room_id))
row = txn.fetchone()
highlight_count = row[0] if row else 0
return {
"notify_count": notify_count,
"highlight_count": highlight_count,
}
ret = yield self.runInteraction( ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room", "get_unread_event_push_actions_by_room",

View File

@ -54,6 +54,7 @@ def encode_json(json_object):
else: else:
return json.dumps(json_object, ensure_ascii=False) return json.dumps(json_object, ensure_ascii=False)
# These values are used in the `enqueus_event` and `_do_fetch` methods to # These values are used in the `enqueus_event` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database. # control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster # The values are plucked out of thing air to make initial sync run faster

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json import simplejson as json
@ -24,6 +25,13 @@ import simplejson as json
class FilteringStore(SQLBaseStore): class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
int(filter_id)
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
def_json = yield self._simple_select_one_onecol( def_json = yield self._simple_select_one_onecol(
table="user_filters", table="user_filters",
keyvalues={ keyvalues={

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 38 SCHEMA_VERSION = 39
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -37,6 +37,13 @@ class UserPresenceState(namedtuple("UserPresenceState",
status_msg (str): User set status message. status_msg (str): User set status message.
""" """
def as_dict(self):
return dict(self._asdict())
@staticmethod
def from_dict(d):
return UserPresenceState(**d)
def copy_and_replace(self, **kwargs): def copy_and_replace(self, **kwargs):
return self._replace(**kwargs) return self._replace(**kwargs)

View File

@ -156,12 +156,20 @@ class PushRuleStore(SQLBaseStore):
event=event, event=event,
) )
local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u)) # We ignore app service users for now. This is so that we don't fill
# up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts.
local_users_in_room = set(
u for u in users_in_room
if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u)
)
# users in the room who have pushers need to get push rules run because # users in the room who have pushers need to get push rules run because
# that's how their pushers work # that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers( if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate, local_users_in_room,
on_invalidate=cache_context.invalidate,
) )
user_ids = set( user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher

View File

@ -405,7 +405,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data room_id, receipt_type, user_id, event_ids, data
) )
max_persisted_id = self._stream_id_gen.get_current_token() max_persisted_id = self._receipts_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id)) defer.returnValue((stream_id, max_persisted_id))

View File

@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
@defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
device_id (str): ID of the device to associate with the access
token
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._refresh_tokens_id_gen.get_next()
yield self._simple_insert(
"refresh_tokens",
{
"id": next_id,
"user_id": user_id,
"token": token,
"device_id": device_id,
},
desc="add_refresh_token_to_user",
)
def register(self, user_id, token=None, password_hash=None, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): create_profile_with_localpart=None, admin=False):
@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
token token
) )
def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
return self.runInteraction(
"exchange_refresh_token",
self._exchange_refresh_token,
refresh_token,
token_generator
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
device_id = rows[0]["device_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
new_token = token_generator(user_id)
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
return user_id, new_token, device_id
@defer.inlineCallbacks @defer.inlineCallbacks
def is_server_admin(self, user): def is_server_admin(self, user):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine from .engines import PostgresEngine, Sqlite3Engine
@ -106,7 +107,11 @@ class RoomStore(SQLBaseStore):
entries = self._simple_select_list_txn( entries = self._simple_select_list_txn(
txn, txn,
table="public_room_list_stream", table="public_room_list_stream",
keyvalues={"room_id": room_id}, keyvalues={
"room_id": room_id,
"appservice_id": None,
"network_id": None,
},
retcols=("stream_id", "visibility"), retcols=("stream_id", "visibility"),
) )
@ -124,6 +129,8 @@ class RoomStore(SQLBaseStore):
"stream_id": next_id, "stream_id": next_id,
"room_id": room_id, "room_id": room_id,
"visibility": is_public, "visibility": is_public,
"appservice_id": None,
"network_id": None,
} }
) )
@ -132,6 +139,87 @@ class RoomStore(SQLBaseStore):
"set_room_is_public", "set_room_is_public",
set_room_is_public_txn, next_id, set_room_is_public_txn, next_id,
) )
self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
is_public):
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
with them, keyed off of an appservice defined `network_id`, which
basically represents a single instance of a bridge to a third party
network.
Args:
room_id (str)
appservice_id (str)
network_id (str)
is_public (bool): Whether to publish or unpublish the room from the
list.
"""
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
self._simple_insert_txn(
txn,
table="appservice_room_list",
values={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
},
)
except self.database_engine.module.IntegrityError:
# We've already inserted, nothing to do.
return
else:
self._simple_delete_txn(
txn,
table="appservice_room_list",
keyvalues={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
},
)
entries = self._simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
"room_id": room_id,
"appservice_id": appservice_id,
"network_id": network_id,
},
retcols=("stream_id", "visibility"),
)
entries.sort(key=lambda r: r["stream_id"])
add_to_stream = True
if entries:
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
self._simple_insert_txn(
txn,
table="public_room_list_stream",
values={
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
"appservice_id": appservice_id,
"network_id": network_id,
}
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn, next_id,
)
self.hs.get_notifier().on_new_replication_data()
def get_public_room_ids(self): def get_public_room_ids(self):
return self._simple_select_onecol( return self._simple_select_onecol(
@ -259,38 +347,96 @@ class RoomStore(SQLBaseStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
def get_public_room_ids_at_stream_id(self, stream_id): @cached(num_args=2, max_entries=100)
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
"""Get pulbic rooms for a particular list, or across all lists.
Args:
stream_id (int)
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
means the main list, None means all lsits.
"""
return self.runInteraction( return self.runInteraction(
"get_public_room_ids_at_stream_id", "get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn, stream_id self.get_public_room_ids_at_stream_id_txn,
stream_id, network_tuple=network_tuple
) )
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id): def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
network_tuple):
return { return {
rm rm
for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items() for rm, vis in self.get_published_at_stream_id_txn(
txn, stream_id, network_tuple=network_tuple
).items()
if vis if vis
} }
def get_published_at_stream_id_txn(self, txn, stream_id): def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
sql = (""" if network_tuple:
SELECT room_id, visibility FROM public_room_list_stream # We want to get from a particular list. No aggregation required.
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id sql = ("""
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
""")
if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
)
else:
txn.execute(
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
return dict(txn.fetchall())
else:
# We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list")
sql = ("""
SELECT room_id, visibility
FROM public_room_list_stream FROM public_room_list_stream
WHERE stream_id <= ? INNER JOIN (
GROUP BY room_id SELECT
) grouped USING (room_id, stream_id) room_id, max(stream_id) AS stream_id, appservice_id,
""") network_id
FROM public_room_list_stream
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
""")
txn.execute(sql, (stream_id,)) txn.execute(
return dict(txn.fetchall()) sql,
(stream_id,)
)
def get_public_room_changes(self, prev_stream_id, new_stream_id): results = {}
# A room is visible if its visible on any list.
for room_id, visibility in txn.fetchall():
results[room_id] = bool(visibility) or results.get(room_id, False)
return results
def get_public_room_changes(self, prev_stream_id, new_stream_id,
network_tuple):
def get_public_room_changes_txn(txn): def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id) then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
)
now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id) now_rooms_dict = self.get_published_at_stream_id_txn(
txn, new_stream_id, network_tuple
)
now_rooms_visible = set( now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis rm for rm, vis in now_rooms_dict.items() if vis
@ -311,7 +457,8 @@ class RoomStore(SQLBaseStore):
def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn): def get_all_new_public_rooms(txn):
sql = (""" sql = ("""
SELECT stream_id, room_id, visibility FROM public_room_list_stream SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ? WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
LIMIT ? LIMIT ?

View File

@ -24,6 +24,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
import logging import logging
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,7 +35,15 @@ RoomsForUser = namedtuple(
) )
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
def __init__(self, hs):
super(RoomMemberStore, self).__init__(hs)
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
def _store_room_members_txn(self, txn, events, backfilled): def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database. """Store a room member in the database.
@ -49,6 +58,8 @@ class RoomMemberStore(SQLBaseStore):
"sender": event.user_id, "sender": event.user_id,
"room_id": event.room_id, "room_id": event.room_id,
"membership": event.membership, "membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
} }
for event in events for event in events
] ]
@ -398,7 +409,7 @@ class RoomMemberStore(SQLBaseStore):
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=member_event_ids, iterable=member_event_ids,
retcols=['user_id'], retcols=['user_id', 'display_name', 'avatar_url'],
keyvalues={ keyvalues={
"membership": Membership.JOIN, "membership": Membership.JOIN,
}, },
@ -406,11 +417,21 @@ class RoomMemberStore(SQLBaseStore):
desc="_get_joined_users_from_context", desc="_get_joined_users_from_context",
) )
users_in_room = set(row["user_id"] for row in rows) users_in_room = {
row["user_id"]: {
"display_name": row["display_name"],
"avatar_url": row["avatar_url"],
}
for row in rows
}
if event is not None and event.type == EventTypes.Member: if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if event.event_id in member_event_ids: if event.event_id in member_event_ids:
users_in_room.add(event.state_key) users_in_room[event.state_key] = {
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
defer.returnValue(users_in_room) defer.returnValue(users_in_room)
@ -448,3 +469,78 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start
)
max_stream_id = progress.get(
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
INSERT_CLUMP_SIZE = 1000
def add_membership_profile_txn(txn):
sql = ("""
SELECT stream_ordering, event_id, events.room_id, content
FROM events
INNER JOIN room_memberships USING (event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND type = 'm.room.member'
ORDER BY stream_ordering DESC
LIMIT ?
""")
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = self.cursor_to_dict(txn)
if not rows:
return 0
min_stream_id = rows[-1]["stream_ordering"]
to_update = []
for row in rows:
event_id = row["event_id"]
room_id = row["room_id"]
try:
content = json.loads(row["content"])
except:
continue
display_name = content.get("displayname", None)
avatar_url = content.get("avatar_url", None)
if display_name or avatar_url:
to_update.append((
display_name, avatar_url, event_id, room_id
))
to_update_sql = ("""
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
""")
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
clump = to_update[index:index + INSERT_CLUMP_SIZE]
txn.executemany(to_update_sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
}
self._background_update_progress_txn(
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
)
if not result:
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result)

View File

@ -0,0 +1,29 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE appservice_room_list(
appservice_id TEXT NOT NULL,
network_id TEXT NOT NULL,
room_id TEXT NOT NULL
);
-- Each appservice can have multiple published room lists associated with them,
-- keyed of a particular network_id
CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list(
appservice_id, network_id, room_id
);
ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT;
ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT;

View File

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id);

View File

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('event_push_actions_highlights_index', '{}');

View File

@ -0,0 +1,22 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE federation_stream_position(
type TEXT NOT NULL,
stream_id INTEGER NOT NULL
);
INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1);
INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events;

View File

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE room_memberships ADD COLUMN display_name TEXT;
ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT;
INSERT into background_updates (update_name, progress_json)
VALUES ('room_membership_profile_update', '{}');

View File

@ -653,7 +653,10 @@ class StateStore(SQLBaseStore):
else: else:
state_dict = results[group] state_dict = results[group]
state_dict.update(group_state_dict) state_dict.update({
(intern_string(k[0]), intern_string(k[1])): v
for k, v in group_state_dict.items()
})
self._state_group_cache.update( self._state_group_cache.update(
cache_seq_num, cache_seq_num,

View File

@ -541,6 +541,9 @@ class StreamStore(SQLBaseStore):
def get_room_max_stream_ordering(self): def get_room_max_stream_ordering(self):
return self._stream_id_gen.get_current_token() return self._stream_id_gen.get_current_token()
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
def get_stream_token_for_event(self, event_id): def get_stream_token_for_event(self, event_id):
"""The stream token for an event """The stream token for an event
Args: Args:
@ -765,3 +768,50 @@ class StreamStore(SQLBaseStore):
"token": end_token, "token": end_token,
}, },
} }
@defer.inlineCallbacks
def get_all_new_events_stream(self, from_id, current_id, limit):
"""Get all new events"""
def get_all_new_events_stream_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
" WHERE"
" ? < e.stream_ordering AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (from_id, current_id, limit))
rows = txn.fetchall()
upper_bound = current_id
if len(rows) == limit:
upper_bound = rows[-1][0]
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn,
)
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))
def get_federation_out_pos(self, typ):
return self._simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
desc="get_federation_out_pos"
)
def update_federation_out_pos(self, typ, stream_id):
return self._simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)

View File

@ -200,25 +200,48 @@ class TransactionStore(SQLBaseStore):
def _set_destination_retry_timings(self, txn, destination, def _set_destination_retry_timings(self, txn, destination,
retry_last_ts, retry_interval): retry_last_ts, retry_interval):
txn.call_after(self.get_destination_retry_timings.invalidate, (destination,)) self.database_engine.lock_table(txn, "destinations")
self._simple_upsert_txn( self._invalidate_cache_and_stream(
txn, self.get_destination_retry_timings, (destination,)
)
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
prev_row = self._simple_select_one_txn(
txn, txn,
"destinations", table="destinations",
keyvalues={ keyvalues={
"destination": destination, "destination": destination,
}, },
values={ retcols=("retry_last_ts", "retry_interval"),
"retry_last_ts": retry_last_ts, allow_none=True,
"retry_interval": retry_interval,
},
insertion_values={
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
}
) )
if not prev_row:
self._simple_insert_txn(
txn,
table="destinations",
values={
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
}
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
self._simple_update_one_txn(
txn,
"destinations",
keyvalues={
"destination": destination,
},
updatevalues={
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
)
def get_destinations_needing_retry(self): def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction. """Get all destinations which are due a retry for sending a transaction.

View File

@ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
class ThirdPartyInstanceID(
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
def __iter__(self):
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a namedtuple of strings, it is deeply immutable.
def __copy__(self):
return self
def __deepcopy__(self, memo):
return self
@classmethod
def from_string(cls, s):
bits = s.split("|", 2)
if len(bits) != 2:
raise SynapseError(400, "Invalid ID %r" % (s,))
return cls(appservice_id=bits[0], network_id=bits[1])
def to_string(self):
return "%s|%s" % (self.appservice_id, self.network_id,)
__str__ = to_string
@classmethod
def create(cls, appservice_id, network_id,):
return cls(appservice_id=appservice_id, network_id=network_id)

View File

@ -24,6 +24,11 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeferredTimedOutError(SynapseError):
def __init__(self):
super(SynapseError).__init__(504, "Timed out")
def unwrapFirstError(failure): def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures. # defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError) failure.trap(defer.FirstError)
@ -89,7 +94,7 @@ class Clock(object):
def timed_out_fn(): def timed_out_fn():
try: try:
ret_deferred.errback(SynapseError(504, "Timed out")) ret_deferred.errback(DeferredTimedOutError())
except: except:
pass pass

View File

@ -197,6 +197,64 @@ class Linearizer(object):
defer.returnValue(_ctx_manager()) defer.returnValue(_ctx_manager())
class Limiter(object):
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few thing happen at a time on a given resource.
Example:
with (yield limiter.queue("test_key")):
# do some work.
"""
def __init__(self, max_count):
"""
Args:
max_count(int): The maximum number of concurrent access
"""
self.max_count = max_count
# key_to_defer is a map from the key to a 2 element list where
# the first element is the number of things executing
# the second element is a list of deferreds for the things blocked from
# executing.
self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key):
entry = self.key_to_defer.setdefault(key, [0, []])
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
# When on of the things currently executing finishes it will callback
# this item so that it can continue executing.
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1].append(new_defer)
with PreserveLoggingContext():
yield new_defer
entry[0] += 1
@contextmanager
def _ctx_manager():
try:
yield
finally:
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
entry[0] -= 1
try:
entry[1].pop(0).callback(None)
except IndexError:
# If nothing else is executing for this key then remove it
# from the map
if entry[0] == 0:
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
class ReadWriteLock(object): class ReadWriteLock(object):
"""A deferred style read write lock. """A deferred style read write lock.

View File

@ -76,15 +76,26 @@ class JsonEncodedObject(object):
d.update(self.unrecognized_keys) d.update(self.unrecognized_keys)
return d return d
def get_internal_dict(self):
d = {
k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
if k in self.valid_keys
}
d.update(self.unrecognized_keys)
return d
def __str__(self): def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
def _encode(obj): def _encode(obj, internal=False):
if type(obj) is list: if type(obj) is list:
return [_encode(o) for o in obj] return [_encode(o, internal=internal) for o in obj]
if isinstance(obj, JsonEncodedObject): if isinstance(obj, JsonEncodedObject):
return obj.get_dict() if internal:
return obj.get_internal_dict()
else:
return obj.get_dict()
return obj return obj

View File

@ -1,369 +0,0 @@
from twisted.internet import defer
from synapse.config._base import ConfigError
from synapse.types import UserID
import ldap3
import ldap3.core.exceptions
import logging
try:
import ldap3
import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
logger = logging.getLogger(__name__)
class LDAPMode(object):
SIMPLE = "simple",
SEARCH = "search",
LIST = (SIMPLE, SEARCH)
class LdapAuthProvider(object):
__version__ = "0.1"
def __init__(self, config, account_handler):
self.account_handler = account_handler
if not ldap3:
raise RuntimeError(
'Missing ldap3 library. This is required for LDAP Authentication.'
)
self.ldap_mode = config.mode
self.ldap_uri = config.uri
self.ldap_start_tls = config.start_tls
self.ldap_base = config.base
self.ldap_attributes = config.attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = config.bind_dn
self.ldap_bind_password = config.bind_password
self.ldap_filter = config.filter
@defer.inlineCallbacks
def check_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Returns:
True if authentication against LDAP was successful
"""
localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting LDAP connection with %s",
self.ldap_uri
)
if self.ldap_mode == LDAPMode.SIMPLE:
result, conn = self._ldap_simple_bind(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP authentication method simple bind returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
result, conn = self._ldap_authenticated_search(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP auth method authenticated search returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
else:
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
try:
logger.info(
"User authenticated against LDAP server: %s",
conn
)
except NameError:
logger.warn(
"Authentication method yielded no LDAP connection, aborting!"
)
defer.returnValue(False)
# check if user with user_id exists
if (yield self.account_handler.check_user_exists(user_id)):
# exists, authentication complete
conn.unbind()
defer.returnValue(True)
else:
# does not exist, fetch metadata for account creation from
# existing ldap connection
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug(
"ldap registration filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
user_id, access_token = (
yield self.account_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
defer.returnValue(True)
else:
if len(conn.response) == 0:
logger.warn("LDAP registration failed, no result.")
else:
logger.warn(
"LDAP registration failed, too many results (%s)",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(False)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@staticmethod
def parse_config(config):
class _LdapConfig(object):
pass
ldap_config = _LdapConfig()
ldap_config.enabled = config.get("enabled", False)
ldap_config.mode = LDAPMode.SIMPLE
# verify config sanity
_require_keys(config, [
"uri",
"base",
"attributes",
])
ldap_config.uri = config["uri"]
ldap_config.start_tls = config.get("start_tls", False)
ldap_config.base = config["base"]
ldap_config.attributes = config["attributes"]
if "bind_dn" in config:
ldap_config.mode = LDAPMode.SEARCH
_require_keys(config, [
"bind_dn",
"bind_password",
])
ldap_config.bind_dn = config["bind_dn"]
ldap_config.bind_password = config["bind_password"]
ldap_config.filter = config.get("filter", None)
# verify attribute lookup
_require_keys(config['attributes'], [
"uid",
"name",
"mail",
])
return ldap_config
def _ldap_simple_bind(self, server, localpart, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.
Returns True, LDAP3Connection
if the bind was successful
Returns False, None
if an error occured
"""
try:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password,
authentication=ldap3.AUTH_SIMPLE)
logger.debug(
"Established LDAP connection in simple bind mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
conn
)
if conn.bind():
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
return True, conn
# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
localpart, conn.result['description']
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _ldap_authenticated_search(self, server, localpart, password):
""" Attempt to login with the preconfigured bind_dn
and then continue searching and filtering within
the base_dn
Returns (True, LDAP3Connection)
if a single matching DN within the base was found
that matched the filter expression, and with which
a successful bind was achieved
The LDAP3Connection returned is the instance that was used to
verify the password not the one using the configured bind_dn.
Returns (False, None)
if an error occured
"""
try:
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established LDAP connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in search mode through StartTLS: %s",
conn
)
if not conn.bind():
logger.warn(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
conn.unbind()
return False, None
# construct search_filter like (uid=localpart)
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
# combine with the AND expression
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug(
"LDAP search filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query
)
if len(conn.response) == 1:
# GOOD: found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('LDAP search found dn: %s', user_dn)
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
conn.unbind()
return self._ldap_simple_bind(server, localpart, password)
else:
# BAD: found 0 or > 1 results, abort!
if len(conn.response) == 0:
logger.info(
"LDAP search returned no results for '%s'",
localpart
)
else:
logger.info(
"LDAP search returned too many (%s) results for '%s'",
len(conn.response), localpart
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _require_keys(config, required):
missing = [key for key in required if key not in config]
if missing:
raise ConfigError(
"LDAP enabled but missing required config values: {}".format(
", ".join(missing)
)
)

View File

@ -121,15 +121,9 @@ class RetryDestinationLimiter(object):
pass pass
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
def err(failure):
logger.exception(
"Failed to store set_destination_retry_timings",
failure.value
)
valid_err_code = False valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException): if exc_type is not None and issubclass(exc_type, CodeMessageException):
valid_err_code = 0 <= exc_val.code < 500 valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500
if exc_type is None or valid_err_code: if exc_type is None or valid_err_code:
# We connected successfully. # We connected successfully.
@ -151,6 +145,15 @@ class RetryDestinationLimiter(object):
retry_last_ts = int(self.clock.time_msec()) retry_last_ts = int(self.clock.time_msec())
self.store.set_destination_retry_timings( @defer.inlineCallbacks
self.destination, retry_last_ts, self.retry_interval def store_retry_timings():
).addErrback(err) try:
yield self.store.set_destination_retry_timings(
self.destination, retry_last_ts, self.retry_interval
)
except:
logger.exception(
"Failed to store set_destination_retry_timings",
)
store_retry_timings()

View File

@ -12,17 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from tests import unittest
import pymacaroons
from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from mock import Mock import synapse.handlers.auth
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.types import UserID from synapse.types import UserID
from tests import unittest
from tests.utils import setup_test_homeserver, mock_getRawHeaders from tests.utils import setup_test_homeserver, mock_getRawHeaders
import pymacaroons
class TestHandlers(object):
def __init__(self, hs):
self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.TestCase):
@ -34,14 +39,17 @@ class AuthTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver(handlers=None) self.hs = yield setup_test_homeserver(handlers=None)
self.hs.get_datastore = Mock(return_value=self.store) self.hs.get_datastore = Mock(return_value=self.store)
self.hs.handlers = TestHandlers(self.hs)
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
self.test_user = "@foo:bar" self.test_user = "@foo:bar"
self.test_token = "_test_token_" self.test_token = "_test_token_"
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
user_info = { user_info = {
"name": self.test_user, "name": self.test_user,
"token_id": "ditto", "token_id": "ditto",
@ -56,7 +64,6 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
@ -66,7 +73,6 @@ class AuthTestCase(unittest.TestCase):
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
user_info = { user_info = {
"name": self.test_user, "name": self.test_user,
"token_id": "ditto", "token_id": "ditto",
@ -158,7 +164,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
user = user_info["user"] user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
@ -168,6 +174,10 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = Mock(return_value={
"is_guest": True,
})
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
@ -179,11 +189,12 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true") macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize() serialized = macaroon.serialize()
user_info = yield self.auth.get_user_from_macaroon(serialized) user_info = yield self.auth.get_user_by_access_token(serialized)
user = user_info["user"] user = user_info["user"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest) self.assertTrue(is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self): def test_get_user_from_macaroon_user_db_mismatch(self):
@ -200,7 +211,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:
yield self.auth.get_user_from_macaroon(macaroon.serialize()) yield self.auth.get_user_by_access_token(macaroon.serialize())
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
self.assertIn("User mismatch", cm.exception.msg) self.assertIn("User mismatch", cm.exception.msg)
@ -220,7 +231,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:
yield self.auth.get_user_from_macaroon(macaroon.serialize()) yield self.auth.get_user_by_access_token(macaroon.serialize())
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
self.assertIn("No user caveat", cm.exception.msg) self.assertIn("No user caveat", cm.exception.msg)
@ -242,7 +253,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:
yield self.auth.get_user_from_macaroon(macaroon.serialize()) yield self.auth.get_user_by_access_token(macaroon.serialize())
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg) self.assertIn("Invalid macaroon", cm.exception.msg)
@ -265,7 +276,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("cunning > fox") macaroon.add_first_party_caveat("cunning > fox")
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:
yield self.auth.get_user_from_macaroon(macaroon.serialize()) yield self.auth.get_user_by_access_token(macaroon.serialize())
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg) self.assertIn("Invalid macaroon", cm.exception.msg)
@ -293,12 +304,12 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 5000 # seconds self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True self.hs.config.expire_access_token = True
# yield self.auth.get_user_from_macaroon(macaroon.serialize()) # yield self.auth.get_user_by_access_token(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we # TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start # validate expiration (and remove the above line, which will start
# throwing). # throwing).
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:
yield self.auth.get_user_from_macaroon(macaroon.serialize()) yield self.auth.get_user_by_access_token(macaroon.serialize())
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg) self.assertIn("Invalid macaroon", cm.exception.msg)
@ -327,6 +338,58 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 5000 # seconds self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True self.hs.config.expire_access_token = True
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
user = user_info["user"] user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = Mock()
token = yield self.hs.handlers.auth_handler.issue_access_token(
USER_ID, "DEVICE"
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE"
)
def get_user(tok):
if token != tok:
return None
return {
"name": USER_ID,
"is_guest": False,
"token_id": 1234,
"device_id": "DEVICE",
}
self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = Mock(return_value={
"is_guest": False,
})
# check the token works
request = Mock(args={})
request.args["access_token"] = [token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)
# add an is_guest caveat
mac = pymacaroons.Macaroon.deserialize(token)
mac.add_first_party_caveat("guest = true")
guest_tok = mac.serialize()
# the token should *not* work now
request = Mock(args={})
request.args["access_token"] = [guest_tok]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(AuthError) as cm:
yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
self.store.get_user_by_id.assert_called_with(USER_ID)

View File

@ -17,7 +17,11 @@
from .. import unittest from .. import unittest
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs):
return FrozenEvent(kwargs)
class PruneEventTestCase(unittest.TestCase): class PruneEventTestCase(unittest.TestCase):
@ -114,3 +118,167 @@ class PruneEventTestCase(unittest.TestCase):
'unsigned': {}, 'unsigned': {},
} }
) )
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
return serialize_event(ev, 1479807801915, only_event_fields=fields)
def test_event_fields_works_with_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar"
),
["room_id"]
),
{
"room_id": "!foo:bar",
}
)
def test_event_fields_works_with_nested_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"body": "A message",
},
),
["content.body"]
),
{
"content": {
"body": "A message",
}
}
)
def test_event_fields_works_with_dot_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"key.with.dots": {},
},
),
["content.key\.with\.dots"]
),
{
"content": {
"key.with.dots": {},
}
}
)
def test_event_fields_works_with_nested_dot_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"not_me": 1,
"nested.dot.key": {
"leaf.key": 42,
"not_me_either": 1,
},
},
),
["content.nested\.dot\.key.leaf\.key"]
),
{
"content": {
"nested.dot.key": {
"leaf.key": 42,
},
}
}
)
def test_event_fields_nops_with_unknown_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"foo": "bar",
},
),
["content.foo", "content.notexists"]
),
{
"content": {
"foo": "bar",
}
}
)
def test_event_fields_nops_with_non_dict_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"foo": ["I", "am", "an", "array"],
},
),
["content.foo.am"]
),
{}
)
def test_event_fields_nops_with_array_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"foo": ["I", "am", "an", "array"],
},
),
["content.foo.1"]
),
{}
)
def test_event_fields_all_fields_if_empty(self):
self.assertEquals(
self.serialize(
MockEvent(
room_id="!foo:bar",
content={
"foo": "bar",
},
),
[]
),
{
"room_id": "!foo:bar",
"content": {
"foo": "bar",
},
"unsigned": {}
}
)
def test_event_fields_fail_if_fields_not_str(self):
with self.assertRaises(TypeError):
self.serialize(
MockEvent(
room_id="!foo:bar",
content={
"foo": "bar",
},
),
["room_id", 4]
)

View File

@ -61,14 +61,14 @@ class AuthTestCase(unittest.TestCase):
def verify_type(caveat): def verify_type(caveat):
return caveat == "type = access" return caveat == "type = access"
def verify_expiry(caveat): def verify_nonce(caveat):
return caveat == "time < 8600000" return caveat.startswith("nonce =")
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_general(verify_gen) v.satisfy_general(verify_gen)
v.satisfy_general(verify_user) v.satisfy_general(verify_user)
v.satisfy_general(verify_type) v.satisfy_general(verify_type)
v.satisfy_general(verify_expiry) v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):

View File

@ -53,13 +53,12 @@ class RegistrationTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):
duration_ms = 200
local_part = "someone" local_part = "someone"
display_name = "someone" display_name = "someone"
user_id = "@someone:test" user_id = "@someone:test"
requester = create_requester("@as:test") requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
requester, local_part, display_name, duration_ms) requester, local_part, display_name)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@ -71,12 +70,11 @@ class RegistrationTestCase(unittest.TestCase):
user_id=frank.to_string(), user_id=frank.to_string(),
token="jkv;g498752-43gj['eamb!-5", token="jkv;g498752-43gj['eamb!-5",
password_hash=None) password_hash=None)
duration_ms = 200
local_part = "frank" local_part = "frank"
display_name = "Frank" display_name = "Frank"
user_id = "@frank:test" user_id = "@frank:test"
requester = create_requester("@as:test") requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
requester, local_part, display_name, duration_ms) requester, local_part, display_name)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')

View File

@ -103,7 +103,7 @@ class ReplicationResourceCase(unittest.TestCase):
room_id = yield self.create_room() room_id = yield self.create_room()
event_id = yield self.send_text_message(room_id, "Hello, World") event_id = yield self.send_text_message(room_id, "Hello, World")
get = self.get(receipts="-1") get = self.get(receipts="-1")
yield self.hs.get_handlers().receipts_handler.received_client_receipt( yield self.hs.get_receipts_handler().received_client_receipt(
room_id, "m.read", self.user_id, event_id room_id, "m.read", self.user_id, event_id
) )
code, body = yield get code, body = yield get

View File

@ -67,8 +67,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock( self.registration_handler.appservice_register = Mock(
return_value=user_id return_value=user_id
) )
self.auth_handler.get_login_tuple_for_user_id = Mock( self.auth_handler.get_access_token_for_user_id = Mock(
return_value=(token, "kermits_refresh_token") return_value=token
) )
(code, result) = yield self.servlet.on_POST(self.request) (code, result) = yield self.servlet.on_POST(self.request)
@ -76,11 +76,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname "home_server": self.hs.hostname
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self):
@ -126,8 +124,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey" "password": "monkey"
}, None) }, None)
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_login_tuple_for_user_id = Mock( self.auth_handler.get_access_token_for_user_id = Mock(
return_value=(token, "kermits_refresh_token") return_value=token
) )
self.device_handler.check_device_registered = \ self.device_handler.check_device_registered = \
Mock(return_value=device_id) Mock(return_value=device_id)
@ -137,12 +135,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
self.auth_handler.get_login_tuple_for_user_id( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None) user_id, device_id=device_id, initial_device_display_name=None)

View File

@ -39,7 +39,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
self.as_token = "token1" self.as_token = "token1"
self.as_url = "some_url" self.as_url = "some_url"
@ -112,7 +112,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
self.as_list = [ self.as_list = [
@ -443,7 +443,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
app_service_config_files=[f1, f2], event_cache_size=1, app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[] password_providers=[]
) )
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
)
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -456,7 +460,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
app_service_config_files=[f1, f2], event_cache_size=1, app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[] password_providers=[]
) )
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
)
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -475,7 +483,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
app_service_config_files=[f1, f2], event_cache_size=1, app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[] password_providers=[]
) )
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
)
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs) ApplicationServiceStore(hs)

Some files were not shown because too many files have changed in this diff Show More