Merge branch 'develop' into erikj/groups_merged

Commit 27955056e0 by David Baker, 2017-10-02 16:20:41 +01:00
31 changed files with 1181 additions and 303 deletions

View File

@ -1,3 +1,43 @@
Changes in synapse v0.23.0-rc2 (2017-09-26)
===========================================
Bug fixes:
* Fix regression in performance of syncs (PR #2470)
Changes in synapse v0.23.0-rc1 (2017-09-25)
===========================================
Features:
* Add a frontend proxy worker (PR #2344)
* Add support for event_id_only push format (PR #2450)
* Add a PoC for filtering spammy events (PR #2456)
* Add a config option to block all room invites (PR #2457)
Changes:
* Use bcrypt module instead of py-bcrypt (PR #2288) Thanks to @kyrias!
* Improve performance of generating push notifications (PR #2343, #2357, #2365,
#2366, #2371)
* Improve DB performance for device list handling in sync (PR #2362)
* Include a sample prometheus config (PR #2416)
* Document known to work postgres version (PR #2433) Thanks to @ptman!
Bug fixes:
* Fix caching error in the push evaluator (PR #2332)
* Fix bug where pusherpool didn't start and broke some rooms (PR #2342)
* Fix port script for user directory tables (PR #2375)
* Fix device lists notifications when user rejoins a room (PR #2443, #2449)
* Fix sync to always send down current state events in timeline (PR #2451)
* Fix bug where guest users were incorrectly kicked (PR #2453)
* Fix bug talking to IPv6 only servers using SRV records (PR #2462)
Changes in synapse v0.22.1 (2017-07-06)
=======================================

View File

@ -200,11 +200,11 @@ different. See `the spec`__ for more information on key management.)
.. __: `key_management`_
The default configuration exposes two HTTP ports: 8008 and 8448. Port 8008 is
-configured without TLS; it is not recommended this be exposed outside your
-local network. Port 8448 is configured to use TLS with a self-signed
-certificate. This is fine for testing with but, to avoid your clients
-complaining about the certificate, you will almost certainly want to use
-another certificate for production purposes. (Note that a self-signed
+configured without TLS; it should be behind a reverse proxy for TLS/SSL
+termination on port 443 which in turn should be used for clients. Port 8448
+is configured to use TLS with a self-signed certificate. If you would like
+to do initial test with a client without having to setup a reverse proxy,
+you can temporarly use another certificate. (Note that a self-signed
certificate is fine for `Federation`_). You can do so by changing
``tls_certificate_path``, ``tls_private_key_path`` and ``tls_dh_params_path``
in ``homeserver.yaml``; alternatively, you can use a reverse-proxy, but be sure
@ -283,10 +283,16 @@ Connecting to Synapse from a client
The easiest way to try out your new Synapse installation is by connecting to it
from a web client. The easiest option is probably the one at
http://riot.im/app. You will need to specify a "Custom server" when you log on
-or register: set this to ``https://localhost:8448`` - remember to specify the
-port (``:8448``) unless you changed the configuration. (Leave the identity
+or register: set this to ``https://domain.tld`` if you setup a reverse proxy
+following the recommended setup, or ``https://localhost:8448`` - remember to specify the
+port (``:8448``) if not ``:443`` unless you changed the configuration. (Leave the identity
server as the default - see `Identity servers`_.)

If using port 8448 you will run into errors until you accept the self-signed
certificate. You can easily do this by going to ``https://localhost:8448``
directly with your browser and accept the presented certificate. You can then
go back in your web client and proceed further.
If all goes well you should at least be able to log in, create a room, and
start sending messages.
@ -593,8 +599,9 @@ you to run your server on a machine that might not have the same name as your
domain name. For example, you might want to run your server at
``synapse.example.com``, but have your Matrix user-ids look like
``@user:example.com``. (A SRV record also allows you to change the port from
-the default 8448. However, if you are thinking of using a reverse-proxy, be
-sure to read `Reverse-proxying the federation port`_ first.)
+the default 8448. However, if you are thinking of using a reverse-proxy on the
+federation port, which is not recommended, be sure to read
+`Reverse-proxying the federation port`_ first.)

To use a SRV record, first create your SRV record and publish it in DNS. This
should have the format ``_matrix._tcp.<yourdomain.com> <ttl> IN SRV 10 0 <port> should have the format ``_matrix._tcp.<yourdomain.com> <ttl> IN SRV 10 0 <port>
@ -674,7 +681,7 @@ For information on how to install and use PostgreSQL, please see
Using a reverse proxy with Synapse
==================================
-It is possible to put a reverse proxy such as
+It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or
`HAProxy <http://www.haproxy.org/>`_ in front of Synapse. One advantage of
@ -692,9 +699,9 @@ federation port has a number of pitfalls. It is possible, but be sure to read
`Reverse-proxying the federation port`_.
The recommended setup is therefore to configure your reverse-proxy on port 443
-for client connections, but to also expose port 8448 for server-server
-connections. All the Matrix endpoints begin ``/_matrix``, so an example nginx
-configuration might look like::
+to port 8008 of synapse for client connections, but to also directly expose port
+8448 for server-server connections. All the Matrix endpoints begin ``/_matrix``,
+so an example nginx configuration might look like::
server {
listen 443 ssl;

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.22.1" __version__ = "0.23.0-rc2"

View File

@ -519,6 +519,14 @@ class Auth(object):
)
def is_server_admin(self, user):
""" Check if the given user is a local server admin.
Args:
user (str): mxid of user to check
Returns:
bool: True if the user is an admin
"""
return self.store.is_server_admin(user)
@defer.inlineCallbacks

View File

@ -34,6 +34,7 @@ from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig
from .workers import WorkerConfig
from .push import PushConfig
from .spam_checker import SpamCheckerConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@ -41,7 +42,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
-WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
+WorkerConfig, PasswordAuthProviderConfig, PushConfig,
+SpamCheckerConfig,):
pass

View File

@ -15,13 +15,15 @@
from ._base import Config, ConfigError
-import importlib
+from synapse.util.module_loader import load_module
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
provider_config = None
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config):
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
from ldap_auth_provider import LdapAuthProvider
provider_class = LdapAuthProvider
else:
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
else:
(provider_class, provider_config) = load_module(provider)
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
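
Not part of this commit: the inline importlib logic removed above is roughly what a helper such as load_module would encapsulate. A minimal sketch, assuming the provider dict carries "module" and "config" keys and the named class exposes a parse_config method (load_module_sketch is an illustrative name):

    import importlib

    def load_module_sketch(provider):
        # Split "pkg.mod.ClassName" into a module path and a class name,
        # mirroring the rsplit(".", 1) logic removed in this hunk.
        module_name, clz = provider["module"].rsplit(".", 1)
        module = importlib.import_module(module_name)
        provider_class = getattr(module, clz)
        # Let the class validate/normalise its own config section.
        provider_config = provider_class.parse_config(provider.get("config"))
        return provider_class, provider_config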

View File

@ -43,6 +43,12 @@ class ServerConfig(Config):
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
# Whether we should block invites sent to users on this server
# (other than those sent by local server admins)
self.block_non_admin_invites = config.get(
"block_non_admin_invites", False,
)
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@ -194,6 +200,10 @@ class ServerConfig(Config):
# and sync operations. The default value is -1, means no upper limit.
# filter_timeline_limit: 5000
# Whether room invites to users on this server should be blocked
# (except those sent by local server admins). The default is False.
# block_non_admin_invites: True
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.module_loader import load_module
from ._base import Config
class SpamCheckerConfig(Config):
def read_config(self, config):
self.spam_checker = None
provider = config.get("spam_checker", None)
if provider is not None:
self.spam_checker = load_module(provider)
def default_config(self, **kwargs):
return """\
# spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
"""

View File

@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util import logcontext
from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import (
preserve_context_over_fn, preserve_context_over_deferred
)
import simplejson as json
import logging
@ -43,14 +40,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
-protocol = yield preserve_context_over_fn(
-endpoint.connect, factory
-)
-server_response, server_certificate = yield preserve_context_over_deferred(
-protocol.remote_key
-)
+with logcontext.PreserveLoggingContext():
+protocol = yield endpoint.connect(factory)
+server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
return
except SynapseKeyClientError as e: except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,)) logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"): if e.status.startswith("4"):

View File

@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
-preserve_context_over_fn, PreserveLoggingContext,
+PreserveLoggingContext,
preserve_fn
)
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -57,7 +57,8 @@ Attributes:
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
-a verify key has been fetched
+a verify key has been fetched. The deferreds' callbacks are run with no
+logcontext.
"""
@ -82,9 +83,11 @@ class Keyring(object):
self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object):
-return self.verify_json_objects_for_server(
+return logcontext.make_deferred_yieldable(
+self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
+)
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
@ -94,8 +97,10 @@ class Keyring(object):
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
-list of deferreds indicating success or failure to verify each
-json object's signature for the given server_name.
+List<Deferred>: for each input pair, a deferred indicating success
+or failure to verify each json object's signature for the given
+server_name. The deferreds run their callbacks in the sentinel
+logcontext.
"""
verify_requests = [] verify_requests = []
@ -122,73 +127,59 @@ class Keyring(object):
verify_requests.append(verify_request) verify_requests.append(verify_request)
preserve_fn(self._start_key_lookups)(verify_requests)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
handle = preserve_fn(_handle_key_deferred)
return [
handle(rq) for rq in verify_requests
]
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_key_deferred(verify_request): def _start_key_lookups(self, verify_requests):
server_name = verify_request.server_name """Sets off the key fetches for each verify request
try:
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
json_object = verify_request.json_object Once each fetch completes, verify_request.deferred will be resolved.
logger.debug("Got key %s %s:%s for server %s, verifying" % ( Args:
key_id, verify_key.alg, verify_key.version, server_name, verify_requests (List[VerifyKeyRequest]):
)) """
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
# create a deferred for each server we're going to look up the keys
# for; we'll resolve them once we have completed our lookups.
# These will be passed into wait_for_previous_lookups to block
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = { server_to_deferred = {
server_name: defer.Deferred() rq.server_name: defer.Deferred()
for server_name, _ in server_and_json for rq in verify_requests
} }
with PreserveLoggingContext():
# We want to wait for any previous lookups to complete before # We want to wait for any previous lookups to complete before
# proceeding. # proceeding.
wait_on_deferred = self.wait_for_previous_lookups( yield self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json], [rq.server_name for rq in verify_requests],
server_to_deferred, server_to_deferred,
) )
# Actually start fetching keys. # Actually start fetching keys.
wait_on_deferred.addBoth( self._get_server_verify_keys(verify_requests)
lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name, # When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that # resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed. # any lookups waiting will proceed.
#
# map from server name to a set of request ids
server_to_request_ids = {} server_to_request_ids = {}
def remove_deferreds(res, server_name, verify_request): for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
def remove_deferreds(res, verify_request):
server_name = verify_request.server_name
request_id = id(verify_request) request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id) server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]: if not server_to_request_ids[server_name]:
@ -198,17 +189,9 @@ class Keyring(object):
return res return res
for verify_request in verify_requests: for verify_request in verify_requests:
server_name = verify_request.server_name verify_request.deferred.addBoth(
request_id = id(verify_request) remove_deferreds, verify_request,
server_to_request_ids.setdefault(server_name, set()).add(request_id) )
deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
preserve_context_over_fn(handle_key_deferred, verify_request)
for verify_request in verify_requests
]
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred): def wait_for_previous_lookups(self, server_names, server_to_deferred):
@ -245,7 +228,7 @@ class Keyring(object):
self.key_downloads[server_name] = deferred self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name) deferred.addBoth(rm, server_name)
-def get_server_verify_keys(self, verify_requests):
+def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request
For each verify_request, verify_request.deferred is called back with For each verify_request, verify_request.deferred is called back with
@ -314,7 +297,8 @@ class Keyring(object):
if not missing_keys:
break
-for verify_request in requests_missing_keys.values():
+with PreserveLoggingContext():
+for verify_request in requests_missing_keys:
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
@ -324,11 +308,12 @@ class Keyring(object):
))
def on_err(err):
with PreserveLoggingContext():
for verify_request in verify_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
-do_iterations().addErrback(on_err)
+preserve_fn(do_iterations)().addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
@ -738,3 +723,47 @@ class Keyring(object):
],
consumeErrors=True,
).addErrback(unwrapFirstError))
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
server_name = verify_request.server_name
try:
with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, verify_request.key_ids),
Codes.UNAUTHORIZED,
)
json_object = verify_request.json_object
logger.debug("Got key %s %s:%s for server %s, verifying" % (
key_id, verify_key.alg, verify_key.version, server_name,
))
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)

View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class SpamChecker(object):
def __init__(self, hs):
self.spam_checker = None
module = None
config = None
try:
module, config = hs.config.spam_checker
except:
pass
if module is not None:
self.spam_checker = module(config=config)
def check_event_for_spam(self, event):
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
sent by a local user. If it is sent by a user on another server, then
users receive a blank event.
Args:
event (synapse.events.EventBase): the event to be checked
Returns:
bool: True if the event is spammy.
"""
if self.spam_checker is None:
return False
return self.spam_checker.check_event_for_spam(event)
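
Not part of this commit: a minimal sketch of what a pluggable checker module might look like. SuperSpamChecker is the placeholder name from the sample config; the config= constructor argument comes from the wrapper above, the parse_config hook is inferred from the module loader used in this change, and blocked_words is a made-up option.

    class SuperSpamChecker(object):
        def __init__(self, config):
            # config is whatever parse_config returned at startup.
            self.blocked_words = config.get("blocked_words", [])

        @staticmethod
        def parse_config(config):
            # Validate the "config" section from homeserver.yaml; raising
            # here makes bad config fail fast at startup.
            return config or {}

        def check_event_for_spam(self, event):
            # Returning True causes local events to be rejected and remote
            # events to be redacted before they reach clients.
            body = event.content.get("body", "")
            return any(word in body for word in self.blocked_words)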

View File

@ -12,28 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging import logging
from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events.utils import prune_event
from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer
logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
-pass
+self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@ -57,56 +49,52 @@ class FederationBase(object):
""" """
deferreds = self._check_sigs_and_hashes(pdus) deferreds = self._check_sigs_and_hashes(pdus)
def callback(pdu): @defer.inlineCallbacks
return pdu def handle_check_result(pdu, deferred):
try:
res = yield logcontext.make_deferred_yieldable(deferred)
except SynapseError:
res = None
def errback(failure, pdu):
failure.trap(SynapseError)
return None
def try_local_db(res, pdu):
if not res: if not res:
# Check local db. # Check local db.
return self.store.get_event( res = yield self.store.get_event(
pdu.event_id, pdu.event_id,
allow_rejected=True, allow_rejected=True,
allow_none=True, allow_none=True,
) )
return res
def try_remote(res, pdu):
if not res and pdu.origin != origin: if not res and pdu.origin != origin:
return self.get_pdu( try:
res = yield self.get_pdu(
destinations=[pdu.origin], destinations=[pdu.origin],
event_id=pdu.event_id, event_id=pdu.event_id,
outlier=outlier, outlier=outlier,
timeout=10000, timeout=10000,
).addErrback(lambda e: None) )
return res except SynapseError:
pass
def warn(res, pdu):
if not res: if not res:
logger.warn( logger.warn(
"Failed to find copy of %s with valid signature", "Failed to find copy of %s with valid signature",
pdu.event_id, pdu.event_id,
) )
return res
for pdu, deferred in zip(pdus, deferreds): defer.returnValue(res)
deferred.addCallbacks(
callback, errback, errbackArgs=[pdu] handle = logcontext.preserve_fn(handle_check_result)
).addCallback( deferreds2 = [
try_local_db, pdu handle(pdu, deferred)
).addCallback( for pdu, deferred in zip(pdus, deferreds)
try_remote, pdu ]
).addCallback(
warn, pdu valid_pdus = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
deferreds2,
consumeErrors=True,
) )
).addErrback(unwrapFirstError)
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
deferreds,
consumeErrors=True
)).addErrback(unwrapFirstError)
if include_none: if include_none:
defer.returnValue(valid_pdus) defer.returnValue(valid_pdus)
@ -114,15 +102,24 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
-return self._check_sigs_and_hashes([pdu])[0]
+return logcontext.make_deferred_yieldable(
+self._check_sigs_and_hashes([pdu])[0],
+)
def _check_sigs_and_hashes(self, pdus):
-"""Throws a SynapseError if a PDU does not have the correct
-signatures.
+"""Checks that each of the received events is correctly signed by the
+sending server.
Args:
pdus (list[FrozenEvent]): the events to be checked
Returns:
-FrozenEvent: Either the given event or it redacted if it failed the
-content hash check.
+list[Deferred]: for each input event, a deferred which:
+* returns the original event if the checks pass
+* returns a redacted version of the event (if the signature
+matched but the hash did not)
+* throws a SynapseError if the signature check failed.
+The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [ redacted_pdus = [
@ -130,22 +127,34 @@ class FederationBase(object):
for pdu in pdus
]
-deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
ctx = logcontext.LoggingContext.current_context()
def callback(_, pdu, redacted): def callback(_, pdu, redacted):
with logcontext.PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu): if not check_event_content_hash(pdu):
logger.warn( logger.warn(
"Event content has been tampered, redacting %s: %s", "Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json() pdu.event_id, pdu.get_pdu_json()
) )
return redacted return redacted
if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return pdu return pdu
def errback(failure, pdu): def errback(failure, pdu):
failure.trap(SynapseError) failure.trap(SynapseError)
with logcontext.PreserveLoggingContext(ctx):
logger.warn( logger.warn(
"Signature check failed for %s", "Signature check failed for %s",
pdu.event_id, pdu.event_id,

View File

@ -22,7 +22,7 @@ from synapse.api.constants import Membership
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -189,10 +189,10 @@ class FederationClient(FederationBase):
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
-pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus), self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
-)).addErrback(unwrapFirstError)
+).addErrback(unwrapFirstError))
defer.returnValue(pdus) defer.returnValue(pdus)
@ -252,7 +252,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
-signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+signed_pdu = yield self._check_sigs_and_hash(pdu)
break break

View File

@ -1074,6 +1074,9 @@ class FederationHandler(BaseHandler):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
membership = event.content.get("membership") membership = event.content.get("membership")
if event.type != EventTypes.Member or membership != Membership.INVITE: if event.type != EventTypes.Member or membership != Membership.INVITE:
raise SynapseError(400, "The event was not an m.room.member invite event") raise SynapseError(400, "The event was not an m.room.member invite event")
@ -2090,6 +2093,14 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
"""Handle an exchange_third_party_invite request from a remote server
The remote server will call this when it wants to turn a 3pid invite
into a normal m.room.member invite.
Returns:
Deferred: resolves (to None)
"""
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
message_handler = self.hs.get_handlers().message_handler message_handler = self.hs.get_handlers().message_handler
@ -2108,9 +2119,12 @@ class FederationHandler(BaseHandler):
raise e raise e
yield self._check_signature(event, context) yield self._check_signature(event, context)
# XXX we send the invite here, but send_membership_event also sends it,
# so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event) returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -59,6 +58,8 @@ class MessageHandler(BaseHandler):
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
@ -322,6 +323,12 @@ class MessageHandler(BaseHandler):
token_id=requester.access_token_id, token_id=requester.access_token_id,
txn_id=txn_id txn_id=txn_id
) )
if self.spam_checker.check_event_for_spam(event):
raise SynapseError(
403, "Spam is not permitted here", Codes.FORBIDDEN
)
yield self.send_nonmember_event( yield self.send_nonmember_event(
requester, requester,
event, event,
@ -413,6 +420,51 @@ class MessageHandler(BaseHandler):
[serialize_event(c, now) for c in room_state.values()] [serialize_event(c, now) for c in room_state.values()]
) )
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
"""Get all the joined members in the room and their profile information.
If the user has left the room return the state events from when they left.
Args:
requester(Requester): The user requesting state events.
room_id(str): The room ID to get all state events from.
Returns:
A dict of user_id to profile info
"""
user_id = requester.user.to_string()
if not requester.app_service:
# We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway.
membership, _ = yield self._check_in_room_or_world_readable(
room_id, user_id
)
if membership != Membership.JOIN:
raise NotImplementedError(
"Getting joined members after leaving is not implemented"
)
users_with_profile = yield self.state.get_current_user_in_room(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or becuase there
# is a user in the room that the AS is "interested in"
if requester.app_service and user_id not in users_with_profile:
for uid in users_with_profile:
if requester.app_service.is_interested_in_user(uid):
break
else:
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
defer.returnValue({
user_id: {
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
for user_id, profile in users_with_profile.iteritems()
})
@measure_func("_create_new_client_event") @measure_func("_create_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):

View File

@ -193,6 +193,8 @@ class RoomMemberHandler(BaseHandler):
if action in ["kick", "unban"]: if action in ["kick", "unban"]:
effective_membership_state = "leave" effective_membership_state = "leave"
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None: if third_party_signed is not None:
replication = self.hs.get_replication_layer() replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite( yield replication.exchange_third_party_invite(
@ -210,6 +212,16 @@ class RoomMemberHandler(BaseHandler):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
if (effective_membership_state == "invite" and
self.hs.config.block_non_admin_invites):
is_requester_admin = yield self.auth.is_server_admin(
requester.user,
)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server",
)
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
@ -473,6 +485,16 @@ class RoomMemberHandler(BaseHandler):
requester, requester,
txn_id txn_id
): ):
if self.hs.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(
requester.user,
)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server",
Codes.FORBIDDEN,
)
invitee = yield self._lookup_3pid( invitee = yield self._lookup_3pid(
id_server, medium, address id_server, medium, address
) )

View File

@ -306,11 +306,6 @@ class SyncHandler(object):
timeline_limit = sync_config.filter_collection.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
# Pull out the current state, as we always want to include those events
# in the timeline if they're there.
current_state_ids = yield self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(current_state_ids.itervalues())
if recents is None or newly_joined_room or timeline_limit < len(recents): if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True limited = True
else: else:
@ -318,6 +313,15 @@ class SyncHandler(object):
if recents: if recents:
recents = sync_config.filter_collection.filter_room_timeline(recents) recents = sync_config.filter_collection.filter_room_timeline(recents)
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(current_state_ids.itervalues())
recents = yield filter_events_for_client( recents = yield filter_events_for_client(
self.store, self.store,
sync_config.user.to_string(), sync_config.user.to_string(),
@ -354,6 +358,15 @@ class SyncHandler(object):
loaded_recents = sync_config.filter_collection.filter_room_timeline( loaded_recents = sync_config.filter_collection.filter_room_timeline(
events events
) )
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(current_state_ids.itervalues())
loaded_recents = yield filter_events_for_client( loaded_recents = yield filter_events_for_client(
self.store, self.store,
sync_config.user.to_string(), sync_config.user.to_string(),
@ -1042,7 +1055,18 @@ class SyncHandler(object):
# We want to figure out if we joined the room at some point since # We want to figure out if we joined the room at some point since
# the last sync (even if we have since left). This is to make sure # the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary # we do send down the room, and with full state, where necessary
old_state_ids = None old_state_ids = None
if room_id in joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly.
# If there are non join member events, but we are still in the room,
# then the user must have left and joined
newly_joined_rooms.append(room_id)
# User is in the room so we don't need to do the invite/leave checks
continue
if room_id in joined_room_ids or has_join: if room_id in joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token) old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
@ -1054,6 +1078,7 @@ class SyncHandler(object):
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id) newly_joined_rooms.append(room_id)
# If user is in the room then we don't need to do the invite/leave checks
if room_id in joined_room_ids: if room_id in joined_room_ids:
continue continue

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import socket
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {} SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
# "host" is actually a dotted-quad or ipv6 address string. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple( _Server = collections.namedtuple(
"_Server", "priority weight host port expires" "_Server", "priority weight host port expires"
) )
@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
return self.default_server return self.default_server
else: else:
raise ConnectError( raise ConnectError(
"Not server available for %s" % self.service_name "No server available for %s" % self.service_name
) )
# look for all servers with the same priority
min_priority = self.servers[0].priority min_priority = self.servers[0].priority
weight_indexes = list( weight_indexes = list(
(index, server.weight + 1) (index, server.weight + 1)
@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes) total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight) target_weight = random.randint(0, total_weight)
for index, weight in weight_indexes: for index, weight in weight_indexes:
target_weight -= weight target_weight -= weight
if target_weight <= 0: if target_weight <= 0:
server = self.servers[index] server = self.servers[index]
# XXX: this looks totally dubious:
#
# (a) we never reuse a server until we have been through
# all of the servers at the same priority, so if the
# weights are A: 100, B:1, we always do ABABAB instead of
# AAAA...AAAB (approximately).
#
# (b) After using all the servers at the lowest priority,
# we move onto the next priority. We should only use the
# second priority if servers at the top priority are
# unreachable.
#
del self.servers[index] del self.servers[index]
self.used_servers.append(server) self.used_servers.append(server)
return server return server
@ -280,18 +296,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue
payload = answer.payload
-host = str(payload.target)
-srv_ttl = answer.ttl
-try:
-answers, _, _ = yield dns_client.lookupAddress(host)
-except DNSNameError:
-continue
-for answer in answers:
-if answer.type == dns.A and answer.payload:
-ip = answer.payload.dottedQuad()
-host_ttl = min(srv_ttl, answer.ttl)
+hosts = yield _get_hosts_for_srv_record(
+dns_client, str(payload.target)
+)
+for (ip, ttl) in hosts:
+host_ttl = min(answer.ttl, ttl)
servers.append(_Server(
host=ip,
@ -317,3 +328,80 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e raise e
defer.returnValue(servers) defer.returnValue(servers)
@defer.inlineCallbacks
def _get_hosts_for_srv_record(dns_client, host):
"""Look up each of the hosts in a SRV record
Args:
dns_client (twisted.names.dns.IResolver):
host (basestring): host to look up
Returns:
Deferred[list[(str, int)]]: a list of (host, ttl) pairs
"""
ip4_servers = []
ip6_servers = []
def cb(res):
# lookupAddress and lookupIP6Address return a three-tuple
# giving the answer, authority, and additional sections of the
# response.
#
# we only care about the answers.
return res[0]
def eb(res, record_type):
if res.check(DNSNameError):
return []
logger.warn("Error looking up %s for %s: %s",
record_type, host, res, res.value)
return res
# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
results = yield defer.DeferredList(
[d1, d2], consumeErrors=True)
# if all of the lookups failed, raise an exception rather than blowing out
# the cache with an empty result.
if results and all(s == defer.FAILURE for (s, _) in results):
defer.returnValue(results[0][1])
for (success, result) in results:
if success == defer.FAILURE:
continue
for answer in result:
if not answer.payload:
continue
try:
if answer.type == dns.A:
ip = answer.payload.dottedQuad()
ip4_servers.append((ip, answer.ttl))
elif answer.type == dns.AAAA:
ip = socket.inet_ntop(
socket.AF_INET6, answer.payload.address,
)
ip6_servers.append((ip, answer.ttl))
else:
# the most likely candidate here is a CNAME record.
# rfc2782 says srvs may not point to aliases.
logger.warn(
"Ignoring unexpected DNS record type %s for %s",
answer.type, host,
)
continue
except Exception as e:
logger.warn("Ignoring invalid DNS response for %s: %s",
host, e)
continue
# keep the ipv4 results before the ipv6 results, mostly to match historical
# behaviour.
defer.returnValue(ip4_servers + ip6_servers)
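
Not part of this commit: a rough usage sketch of the new helper, assuming it is driven from an inlineCallbacks function with the default twisted.names client resolver (the hostname and result are made up):

    from twisted.internet import defer
    from twisted.names import client

    @defer.inlineCallbacks
    def print_srv_hosts():
        # client is the default resolver used by resolve_service above.
        hosts = yield _get_hosts_for_srv_record(client, "synapse.example.com")
        for ip, ttl in hosts:
            # A results come first, then AAAA, e.g. ("198.51.100.10", 300)
            print(ip, ttl)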

View File

@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object):
raise raise
logger.warn( logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s", "{%s} Sending request failed to %s: %s %s: %s",
txn_id, txn_id,
destination, destination,
method, method,
url_bytes, url_bytes,
type(e).__name__,
_flatten_response_never_received(e), _flatten_response_never_received(e),
) )
log_result = "%s - %s" % ( log_result = _flatten_response_never_received(e)
type(e).__name__, _flatten_response_never_received(e),
)
if retries_left and not timeout: if retries_left and not timeout:
if long_retries: if long_retries:
@ -618,12 +615,14 @@ class _JsonProducer(object):
def _flatten_response_never_received(e): def _flatten_response_never_received(e):
if hasattr(e, "reasons"): if hasattr(e, "reasons"):
return ", ".join( reasons = ", ".join(
_flatten_response_never_received(f.value) _flatten_response_never_received(f.value)
for f in e.reasons for f in e.reasons
) )
return "%s:[%s]" % (type(e).__name__, reasons)
else: else:
return "%s: %s" % (type(e).__name__, e.message,) return repr(e)
def check_content_type_is_json(headers): def check_content_type_is_json(headers):

View File

@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs) super(JoinedRoomMemberListRestServlet, self).__init__(hs)
-self.state = hs.get_state_handler()
+self.message_handler = hs.get_handlers().message_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
-yield self.auth.get_user_by_req(request)
-users_with_profile = yield self.state.get_current_user_in_room(room_id)
+requester = yield self.auth.get_user_by_req(request)
+users_with_profile = yield self.message_handler.get_joined_members(
+requester, room_id,
+)
defer.returnValue((200, {
-"joined": {
-user_id: {
-"avatar_url": profile.avatar_url,
-"display_name": profile.display_name,
-}
-for user_id, profile in users_with_profile.iteritems()
-}
+"joined": users_with_profile,
}))

View File

@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
import os import os
import re
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
class MediaFilePaths(object): class MediaFilePaths(object):
@ -73,19 +76,105 @@ class MediaFilePaths(object):
) )
def url_cache_filepath(self, media_id): def url_cache_filepath(self, media_id):
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
return os.path.join(
self.base_path, "url_cache",
-media_id[0:2], media_id[2:4], media_id[4:]
+media_id[:10], media_id[11:]
)
else:
return os.path.join(
self.base_path, "url_cache",
media_id[0:2], media_id[2:4], media_id[4:],
)
def url_cache_filepath_dirs_to_delete(self, media_id):
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
self.base_path, "url_cache",
media_id[:10],
),
]
else:
return [
os.path.join(
self.base_path, "url_cache",
media_id[0:2], media_id[2:4],
),
os.path.join(
self.base_path, "url_cache",
media_id[0:2],
),
]
def url_cache_thumbnail(self, media_id, width, height, content_type,
method):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[:10], media_id[11:],
file_name
)
else:
return os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)
def url_cache_thumbnail_directory(self, media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[:10], media_id[11:],
)
else:
return os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
)
def url_cache_thumbnail_dirs_to_delete(self, media_id):
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[:10], media_id[11:],
),
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[:10],
),
]
else:
return [
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
),
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2], media_id[2:4],
),
os.path.join(
self.base_path, "url_cache_thumbnails",
media_id[0:2],
),
]
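
Not part of this commit: a small illustration of how the two media-id formats map to url cache paths, assuming a MediaFilePaths rooted at /data/media and made-up media ids:

    paths = MediaFilePaths("/data/media")

    # New-format id: <DATE>-<RANDOM_STRING>, grouped by date
    paths.url_cache_filepath("2017-09-28-abcdefghijklmnop")
    # -> /data/media/url_cache/2017-09-28/abcdefghijklmnop

    # Old-format id: sharded by its first four characters
    paths.url_cache_filepath("GerZNDnDZVjsOtar")
    # -> /data/media/url_cache/Ge/rZ/NDnDZVjsOtar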

View File

@ -36,6 +36,9 @@ import cgi
import ujson as json import ujson as json
import urlparse import urlparse
import itertools import itertools
import datetime
import errno
import shutil
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,6 +73,10 @@ class PreviewUrlResource(Resource):
self.downloads = {} self.downloads = {}
self._cleaner_loop = self.clock.looping_call(
self._expire_url_cache_data, 10 * 1000
)
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
return NOT_DONE_YET return NOT_DONE_YET
@ -130,7 +137,7 @@ class PreviewUrlResource(Resource):
cache_result = yield self.store.get_url_cache(url, ts) cache_result = yield self.store.get_url_cache(url, ts)
if ( if (
cache_result and cache_result and
cache_result["download_ts"] + cache_result["expires"] > ts and cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2 cache_result["response_code"] / 100 == 2
): ):
respond_with_json_bytes( respond_with_json_bytes(
@ -239,7 +246,7 @@ class PreviewUrlResource(Resource):
url, url,
media_info["response_code"], media_info["response_code"],
media_info["etag"], media_info["etag"],
media_info["expires"], media_info["expires"] + media_info["created_ts"],
json.dumps(og), json.dumps(og),
media_info["filesystem_id"], media_info["filesystem_id"],
media_info["created_ts"], media_info["created_ts"],
@ -253,8 +260,7 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a # we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot? # bot, so are we really a robot?
-# XXX: horrible duplication with base_resource's _download_remote_file()
-file_id = random_string(24)
+file_id = datetime.date.today().isoformat() + '_' + random_string(16)
fname = self.filepaths.url_cache_filepath(file_id) fname = self.filepaths.url_cache_filepath(file_id)
self.media_repo._makedirs(fname) self.media_repo._makedirs(fname)
@ -328,6 +334,88 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None, "etag": headers["ETag"][0] if "ETag" in headers else None,
}) })
@defer.inlineCallbacks
def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
now = self.clock.time_msec()
# First we delete expired url cache entries
media_ids = yield self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
try:
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
for dir in dirs:
os.rmdir(dir)
except:
pass
yield self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
# Now we delete old images associated with the url cache.
# These images may still be cached on the client for a while (e.g. the
# user may have a room open that is showing a URL preview of them).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
media_ids = yield self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
fname = self.filepaths.url_cache_filepath(media_id)
try:
os.remove(fname)
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e)
continue
try:
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
for dir in dirs:
os.rmdir(dir)
except:
pass
thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
try:
shutil.rmtree(thumbnail_dir)
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
logger.warn("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
try:
dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
for dir in dirs:
os.rmdir(dir)
except:
pass
yield self.store.delete_url_cache_media(removed_media)
if removed_media:
logger.info("Deleted %d media from url cache", len(removed_media))
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree

View File

@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transport.client import TransportLayerClient
@ -148,6 +149,7 @@ class HomeServer(object):
'groups_server_handler',
'groups_attestation_signing',
'groups_attestation_renewer',
'spam_checker',
]
def __init__(self, hostname, **kwargs):
@ -333,6 +335,9 @@ class HomeServer(object):
def build_groups_attestation_renewer(self):
return GroupAttestionRenewer(self)
def build_spam_checker(self):
return SpamChecker(self)
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
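For readers unfamiliar with this class: each name in the dependency list is paired with a build_<name>() method, and the corresponding get_<name>() accessor builds the object once and then reuses it for the lifetime of the homeserver. A much-simplified illustration of that pattern (not Synapse's actual implementation) is:

class MiniHomeServer(object):
    """Toy version of the builder/cache pattern used by HomeServer."""

    def __init__(self):
        self._cached = {}

    def build_spam_checker(self):
        return object()  # stand-in for SpamChecker(self)

    def get_spam_checker(self):
        # Built lazily on first access, then cached.
        if "spam_checker" not in self._cached:
            self._cached["spam_checker"] = self.build_spam_checker()
        return self._cached["spam_checker"]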

View File

@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key
defer.returnValue(keys)
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
Args:
server_name (str): The name of the server.
from_server (str): Where the verification key was looked up
time_now_ms (int): The time now in milliseconds
verify_key (nacl.signing.VerifyKey): The NACL verify key.
"""
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
def _txn(txn):
self._simple_upsert_txn(
txn,
table="server_signature_keys", table="server_signature_keys",
keyvalues={ keyvalues={
"server_name": server_name, "server_name": server_name,
"key_id": "%s:%s" % (verify_key.alg, verify_key.version), "key_id": key_id,
}, },
values={ values={
"from_server": from_server, "from_server": from_server,
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()), "verify_key": buffer(verify_key.encode()),
}, },
desc="store_server_verify_key",
) )
txn.call_after(
self._get_server_verify_key.invalidate,
(server_name, key_id)
)
return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):

View File

@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore):
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore):
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore):
return None
return dict(zip((
'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
), row))
return self.runInteraction(
"get_url_cache", get_url_cache_txn
)
def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
download_ts):
return self._simple_insert(
"local_media_repository_url_cache",
@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore):
"url": url, "url": url,
"response_code": response_code, "response_code": response_code,
"etag": etag, "etag": etag,
"expires": expires, "expires_ts": expires_ts,
"og": og, "og": og,
"media_id": media_id, "media_id": media_id,
"download_ts": download_ts, "download_ts": download_ts,
@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore):
},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
" ORDER BY expires_ts ASC"
" LIMIT 500"
)
def _get_expired_url_cache_txn(txn):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
def delete_url_cache(self, media_ids):
sql = (
"DELETE FROM local_media_repository_url_cache"
" WHERE media_id = ?"
)
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
" ORDER BY created_ts ASC"
" LIMIT 500"
)
def _get_url_cache_media_before_txn(txn):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
return self.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn,
)
def delete_url_cache_media(self, media_ids):
def _delete_url_cache_media_txn(txn):
sql = (
"DELETE FROM local_media_repository"
" WHERE media_id = ?"
)
txn.executemany(sql, [(media_id,) for media_id in media_ids])
sql = (
"DELETE FROM local_media_repository_thumbnails"
" WHERE media_id = ?"
)
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn,
)

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 44
dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -0,0 +1,38 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
-- indices on expressions until 3.9.
CREATE TABLE local_media_repository_url_cache_new(
url TEXT,
response_code INTEGER,
etag TEXT,
expires_ts BIGINT,
og TEXT,
media_id TEXT,
download_ts BIGINT
);
INSERT INTO local_media_repository_url_cache_new
SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
DROP TABLE local_media_repository_url_cache;
ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
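The INSERT ... SELECT above turns the old relative `expires` column into an absolute `expires_ts` by adding it to `download_ts`. A worked example of the arithmetic, with made-up values:

download_ts = 1506549600000          # when the preview was fetched, in ms since the epoch
expires = 60 * 60 * 1000             # old relative value: valid for one hour
expires_ts = expires + download_ts   # what the migration stores
assert expires_ts == 1506553200000
# The cleanup job can then simply compare expires_ts against the current time.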

View File

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from synapse.config._base import ConfigError
def load_module(provider):
""" Loads a module with its config
Takes a dict with keys 'module' (the module name) and 'config'
(the config dict).
Returns:
Tuple of (provider class, parsed config object)
"""
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
return provider_class, provider_config
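A self-contained usage sketch for load_module(); the provider class and config keys below are invented for illustration, but the dict shape ('module' plus 'config') and the parse_config hook match the docstring above.

class MySpamChecker(object):
    """Hypothetical provider used only to demonstrate load_module()."""

    def __init__(self, config):
        self.config = config

    @staticmethod
    def parse_config(config):
        return config

provider = {
    "module": "__main__.MySpamChecker",  # would normally be a real module path
    "config": {"blocked_words": ["spam"]},
}

provider_class, provider_config = load_module(provider)
checker = provider_class(provider_config)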

View File

@ -12,17 +12,65 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import signedjson.key
import signedjson.sign
from mock import Mock
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.util import async, logcontext
from synapse.util.logcontext import LoggingContext
from tests import unittest, utils
from twisted.internet import defer
class MockPerspectiveServer(object):
def __init__(self):
self.server_name = "mock_server"
self.key = signedjson.key.generate_signing_key(0)
def get_verify_keys(self):
vk = signedjson.key.get_verify_key(self.key)
return {
"%s:%s" % (vk.alg, vk.version): vk,
}
def get_signed_key(self, server_name, verify_key):
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
res = {
"server_name": server_name,
"old_verify_keys": {},
"valid_until_ts": time.time() * 1000 + 3600,
"verify_keys": {
key_id: {
"key": signedjson.key.encode_verify_key_base64(verify_key)
}
}
}
signedjson.sign.sign_json(res, self.server_name, self.key)
return res
class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
self.hs = yield utils.setup_test_homeserver(
handlers=None,
http_client=self.http_client,
)
self.hs.config.perspectives = {
self.mock_perspective_server.server_name:
self.mock_perspective_server.get_verify_keys()
}
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "test_key", None),
expected
)
@defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
@ -30,11 +78,6 @@ class KeyringTestCase(unittest.TestCase):
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()
@ -50,7 +93,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertTrue(wait_1_deferred.called)
# ... so we should have preserved the LoggingContext.
self.assertIs(LoggingContext.current_context(), context_one)
wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
context_two.test_key = "two"
@ -64,7 +107,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertFalse(wait_2_deferred.called)
# ... so we should have reset the LoggingContext.
self.assertIs(LoggingContext.current_context(), sentinel_context)
wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)
@ -72,3 +115,115 @@ class KeyringTestCase(unittest.TestCase):
# now the second wait should complete and restore our
# loggingcontext.
yield wait_2_deferred
@defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs)
json1 = {}
signedjson.sign.sign_json(json1, "server10", key1)
persp_resp = {
"server_keys": [
self.mock_perspective_server.get_signed_key(
"server10",
signedjson.key.get_verify_key(key1)
),
]
}
persp_deferred = defer.Deferred()
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
LoggingContext.current_context().test_key, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
defer.returnValue(persp_resp)
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
context_11.test_key = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1),
("server11", {})
]
)
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
yield res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
yield async.sleep(0.005)
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
context_12.test_key = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
self.http_client.post_json.reset_mock()
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
yield async.sleep(0.005)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
# complete the first request
with logcontext.PreserveLoggingContext():
persp_deferred.callback(persp_resp)
self.assertIs(LoggingContext.current_context(), context_11)
with logcontext.PreserveLoggingContext():
yield res_deferreds[0]
yield res_deferreds_2[0]
@defer.inlineCallbacks
def test_verify_json_for_server(self):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
yield self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000,
signedjson.key.get_verify_key(key1),
)
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
context_one.test_key = "one"
defer = kr.verify_json_for_server("server9", {})
try:
yield defer
self.fail("should fail on unsigned json")
except SynapseError:
pass
self.assertIs(LoggingContext.current_context(), context_one)
defer = kr.verify_json_for_server("server9", json1)
self.assertFalse(defer.called)
self.assertIs(LoggingContext.current_context(), sentinel_context)
yield defer
self.assertIs(LoggingContext.current_context(), context_one)

View File

@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
from tests.utils import MockClock
class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve(self):
dns_client_mock = Mock()
service_name = "test_service.example.com"
host_name = "example.com"
ip_address = "127.0.0.1"
ip6_address = "::1"
answer_srv = dns.RRHeader(
type=dns.SRV,
@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
)
)
answer_aaaa = dns.RRHeader(
type=dns.AAAA,
payload=dns.Record_AAAA(
address=ip6_address,
)
)
dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None),
)
dns_client_mock.lookupAddress.return_value = defer.succeed(
([answer_a], None, None),
)
dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
([answer_aaaa], None, None),
)
cache = {}
@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
self.assertEquals(len(servers), 2)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address)
self.assertEquals(servers[1].host, ip6_address)
@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):

View File

@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.worker_replication_url = ""
config.worker_app = None
config.email_enable_notifs = False
config.block_non_admin_invites = False
config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}