Merge branch 'develop' of github.com:matrix-org/synapse into erikj/msc_1813

This commit is contained in:
Erik Johnston 2019-01-25 09:57:01 +00:00
commit 62514bb81b
63 changed files with 1455 additions and 1021 deletions

View File

@ -1,11 +1,7 @@
[run] [run]
branch = True branch = True
parallel = True parallel = True
source = synapse include = synapse/*
[paths]
source=
coverage
[report] [report]
precision = 2 precision = 2

6
.gitignore vendored
View File

@ -25,9 +25,9 @@ homeserver*.pid
*.tls.dh *.tls.dh
*.tls.key *.tls.key
.coverage .coverage*
.coverage.* coverage.*
!.coverage.rc !.coveragerc
htmlcov htmlcov
demo/*/*.db demo/*/*.db

1
changelog.d/4306.misc Normal file
View File

@ -0,0 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

1
changelog.d/4384.feature Normal file
View File

@ -0,0 +1 @@
Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt).

1
changelog.d/4405.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug when rejecting remote invites

1
changelog.d/4428.misc Normal file
View File

@ -0,0 +1 @@
Move SRV logic into the Agent layer

1
changelog.d/4432.misc Normal file
View File

@ -0,0 +1 @@
Apply a unique index to the user_ips table, preventing duplicates.

1
changelog.d/4433.misc Normal file
View File

@ -0,0 +1 @@
debian package: symlink to explicit python version

1
changelog.d/4435.bugfix Normal file
View File

@ -0,0 +1 @@
Fix None guard in calling config.server.is_threepid_reserved

View File

@ -1 +1 @@
Add support for persisting event format versions in the database Add infrastructure to support different event formats

1
changelog.d/4444.misc Normal file
View File

@ -0,0 +1 @@
Generate the debian config during build

1
changelog.d/4445.feature Normal file
View File

@ -0,0 +1 @@
Add a metric for tracking event stream position of the user directory.

1
changelog.d/4448.misc Normal file
View File

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4452.bugfix Normal file
View File

@ -0,0 +1 @@
Don't send IP addresses as SNI

1
changelog.d/4458.misc Normal file
View File

@ -0,0 +1 @@
Clarify documentation for the `public_baseurl` config param

1
changelog.d/4459.misc Normal file
View File

@ -0,0 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

1
changelog.d/4460.bugfix Normal file
View File

@ -0,0 +1 @@
Fix UnboundLocalError in post_urlencoded_get_json

1
changelog.d/4461.bugfix Normal file
View File

@ -0,0 +1 @@
Add a timeout to filtered room directory queries.

1
changelog.d/4464.misc Normal file
View File

@ -0,0 +1 @@
Move SRV logic into the Agent layer

View File

@ -6,7 +6,16 @@
set -e set -e
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
SNAKE=/usr/bin/python3
# make sure that the virtualenv links to the specific version of python, by
# dereferencing the python3 symlink.
#
# Otherwise, if somebody tries to install (say) the stretch package on buster,
# they will get a confusing error about "No module named 'synapse'", because
# python won't look in the right directory. At least this way, the error will
# be a *bit* more obvious.
#
SNAKE=`readlink -e /usr/bin/python3`
# try to set the CFLAGS so any compiled C extensions are compiled with the most # try to set the CFLAGS so any compiled C extensions are compiled with the most
# generic as possible x64 instructions, so that compiling it on a new Intel chip # generic as possible x64 instructions, so that compiling it on a new Intel chip
@ -36,6 +45,10 @@ dh_virtualenv \
--extra-pip-arg="--compile" \ --extra-pip-arg="--compile" \
--extras="all" --extras="all"
PACKAGE_BUILD_DIR="debian/matrix-synapse-py3"
VIRTUALENV_DIR="${PACKAGE_BUILD_DIR}${DH_VIRTUALENV_INSTALL_ROOT}/matrix-synapse"
TARGET_PYTHON="${VIRTUALENV_DIR}/bin/python"
# we copy the tests to a temporary directory so that we can put them on the # we copy the tests to a temporary directory so that we can put them on the
# PYTHONPATH without putting the uninstalled synapse on the pythonpath. # PYTHONPATH without putting the uninstalled synapse on the pythonpath.
tmpdir=`mktemp -d` tmpdir=`mktemp -d`
@ -44,5 +57,35 @@ trap "rm -r $tmpdir" EXIT
cp -r tests "$tmpdir" cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \ PYTHONPATH="$tmpdir" \
debian/matrix-synapse-py3/opt/venvs/matrix-synapse/bin/python \ "${TARGET_PYTHON}" -B -m twisted.trial --reporter=text -j2 tests
-B -m twisted.trial --reporter=text -j2 tests
# build the config file
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_config" \
--config-dir="/etc/matrix-synapse" \
--data-dir="/var/lib/matrix-synapse" |
perl -pe '
# tweak the paths to the tls certs and signing keys
/^tls_.*_path:/ and s/SERVERNAME/homeserver/;
/^signing_key_path:/ and s/SERVERNAME/homeserver/;
# tweak the pid file location
/^pid_file:/ and s#:.*#: "/var/run/matrix-synapse.pid"#;
# tweak the path to the log config
/^log_config:/ and s/SERVERNAME\.log\.config/log.yaml/;
# tweak the path to the media store
/^media_store_path:/ and s#/media_store#/media#;
# remove the server_name setting, which is set in a separate file
/^server_name:/ and $_ = "#\n# This is set in /etc/matrix-synapse/conf.d/server_name.yaml for Debian installations.\n# $_";
# remove the report_stats setting, which is set in a separate file
/^# report_stats:/ and $_ = "";
' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml"
# add a dependency on the right version of python to substvars.
PYPKG=`basename $SNAKE`
echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars

2
debian/control vendored
View File

@ -27,8 +27,8 @@ Depends:
adduser, adduser,
debconf, debconf,
python3-distutils|libpython3-stdlib (<< 3.6), python3-distutils|libpython3-stdlib (<< 3.6),
python3,
${misc:Depends}, ${misc:Depends},
${synapse:pydepends},
# some of our scripts use perl, but none of them are important, # some of our scripts use perl, but none of them are important,
# so we put perl:Depends in Suggests rather than Depends. # so we put perl:Depends in Suggests rather than Depends.
Suggests: Suggests:

614
debian/homeserver.yaml vendored
View File

@ -1,614 +0,0 @@
# vim:ft=yaml
# PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
# if you like. Any required intermediary certificates can be
# appended after the primary certificate in hierarchical order.
tls_certificate_path: "/etc/matrix-synapse/homeserver.tls.crt"
# PEM encoded private key for TLS
tls_private_key_path: "/etc/matrix-synapse/homeserver.tls.key"
# Don't bind to the https port
no_tls: False
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
# make HTTPS requests to this server will check that the TLS
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds the fingerprint of its own certificate
# to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it
# will be necessary to add the fingerprints of the certificates used by
# the loadbalancers to this list if they are different to the one
# synapse is using.
#
# Homeservers are permitted to cache the list of TLS fingerprints
# returned in the key responses up to the "valid_until_ts" returned in
# key. It may be necessary to publish the fingerprints of a new
# certificate and wait until the "valid_until_ts" of the previous key
# responses have passed before deploying it.
#
# You can calculate a fingerprint from a given TLS listener via:
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
# or by checking matrix.org/federationtester/api/report?server_name=$host
#
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
## Server ##
# When running as a daemon, the file to store the pid in
pid_file: "/var/run/matrix-synapse.pid"
# CPU affinity mask. Setting this restricts the CPUs on which the
# process will be scheduled. It is represented as a bitmask, with the
# lowest order bit corresponding to the first logical CPU and the
# highest order bit corresponding to the last logical CPU. Not all CPUs
# may exist on a given system but a mask may specify more CPUs than are
# present.
#
# For example:
# 0x00000001 is processor #0,
# 0x00000003 is processors #0 and #1,
# 0xFFFFFFFF is all processors (#0 through #31).
#
# Pinning a Python process to a single CPU is desirable, because Python
# is inherently single-threaded due to the GIL, and can suffer a
# 30-40% slowdown due to cache blow-out and thread context switching
# if the scheduler happens to schedule the underlying threads across
# different cores. See
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
#
# cpu_affinity: 0xFFFFFFFF
# The path to the web client which will be served at /_matrix/client/
# if 'webclient' is configured under the 'listeners' configuration.
#
# web_client_location: "/path/to/web/root"
# The public-facing base URL for the client API (not including _matrix/...)
# public_baseurl: https://example.com:8448/
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
soft_file_limit: 0
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is -1, means no upper limit.
# filter_timeline_limit: 5000
# Whether room invites to users on this server should be blocked
# (except those sent by local server admins). The default is False.
# block_non_admin_invites: True
# Restrict federation to the following whitelist of domains.
# N.B. we recommend also firewalling your federation listener to limit
# inbound federation traffic as early as possible, rather than relying
# purely on this application-layer restriction. If not specified, the
# default is to whitelist everything.
#
# federation_domain_whitelist:
# - lon.example.com
# - nyc.example.com
# - syd.example.com
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
# Main HTTPS listener
# For when matrix traffic is sent directly to synapse.
-
# The port to listen for HTTPS requests on.
port: 8448
# Local addresses to listen on.
# On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
# addresses by default. For most other OSes, this will only listen
# on IPv6.
bind_addresses:
- '::'
- '0.0.0.0'
# This is a 'http' listener, allows us to specify 'resources'.
type: http
tls: true
# Use the X-Forwarded-For (XFF) header as the client IP and not the
# actual client IP.
x_forwarded: false
# List of HTTP resources to serve on this listener.
resources:
-
# List of resources to host on this listener.
names:
- client # The client-server APIs, both v1 and v2
- webclient # The bundled webclient.
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
compress: true
- names: [federation] # Federation APIs
compress: false
# optional list of additional endpoints which can be loaded via
# dynamic modules
# additional_resources:
# "/_matrix/my/custom/endpoint":
# module: my_module.CustomRequestHandler
# config: {}
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: 8008
tls: false
bind_addresses: ['::', '0.0.0.0']
type: http
x_forwarded: false
resources:
- names: [client, webclient]
compress: true
- names: [federation]
compress: false
# Turn on the twisted ssh manhole service on localhost on the given
# port.
# - port: 9000
# bind_addresses: ['::1', '127.0.0.1']
# type: manhole
# Database configuration
database:
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "/var/lib/matrix-synapse/homeserver.db"
# Number of events to cache in memory.
event_cache_size: "10K"
# A yaml python logging config file
log_config: "/etc/matrix-synapse/log.yaml"
## Ratelimiting ##
# Number of messages a client can send per second
rc_messages_per_second: 0.2
# Number of message a client can send before being throttled
rc_message_burst_count: 10.0
# The federation window size in milliseconds
federation_rc_window_size: 1000
# The number of federation requests from a single server in a window
# before the server will delay processing the request.
federation_rc_sleep_limit: 10
# The duration in milliseconds to delay processing events from
# remote servers by if they go over the sleep limit.
federation_rc_sleep_delay: 500
# The maximum number of concurrent federation requests allowed
# from a single server
federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
federation_rc_concurrent: 3
# Directory where uploaded images and attachments are stored.
media_store_path: "/var/lib/matrix-synapse/media"
# Media storage providers allow media to be stored in different
# locations.
# media_storage_providers:
# - module: file_system
# # Whether to write new local files.
# store_local: false
# # Whether to write new remote media
# store_remote: false
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory
# Directory where in-progress uploads are stored.
uploads_path: "/var/lib/matrix-synapse/uploads"
# The largest allowed upload size in bytes
max_upload_size: "10M"
# Maximum number of pixels that will be thumbnailed
max_image_pixels: "32M"
# Whether to generate new thumbnails on the fly to precisely match
# the resolution requested by the client. If true then whenever
# a new resolution is requested by the client the server will
# generate a new thumbnail. If false the server will pick a thumbnail
# from a precalculated list.
dynamic_thumbnails: false
# List of thumbnail to precalculate when an image is uploaded.
thumbnail_sizes:
- width: 32
height: 32
method: crop
- width: 96
height: 96
method: crop
- width: 320
height: 240
method: scale
- width: 640
height: 480
method: scale
- width: 800
height: 600
method: scale
# Is the preview URL API enabled? If enabled, you *must* specify
# an explicit url_preview_ip_range_blacklist of IPs that the spider is
# denied from accessing.
url_preview_enabled: False
# List of IP address CIDR ranges that the URL preview spider is denied
# from accessing. There are no defaults: you must explicitly
# specify a list for URL previewing to work. You should specify any
# internal services in your network that you do not want synapse to try
# to connect to, otherwise anyone in any Matrix room could cause your
# synapse to issue arbitrary GET requests to your internal services,
# causing serious security issues.
#
# url_preview_ip_range_blacklist:
# - '127.0.0.0/8'
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '169.254.0.0/16'
#
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
# This is useful for specifying exceptions to wide-ranging blacklisted
# target IP ranges - e.g. for enabling URL previews for a specific private
# website only visible in your network.
#
# url_preview_ip_range_whitelist:
# - '192.168.1.1'
# Optional list of URL matches that the URL preview spider is
# denied from accessing. You should use url_preview_ip_range_blacklist
# in preference to this, otherwise someone could define a public DNS
# entry that points to a private IP address and circumvent the blacklist.
# This is more useful if you know there is an entire shape of URL that
# you know that will never want synapse to try to spider.
#
# Each list entry is a dictionary of url component attributes as returned
# by urlparse.urlsplit as applied to the absolute form of the URL. See
# https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit
# The values of the dictionary are treated as an filename match pattern
# applied to that component of URLs, unless they start with a ^ in which
# case they are treated as a regular expression match. If all the
# specified component matches for a given list item succeed, the URL is
# blacklisted.
#
# url_preview_url_blacklist:
# # blacklist any URL with a username in its URI
# - username: '*'
#
# # blacklist all *.google.com URLs
# - netloc: 'google.com'
# - netloc: '*.google.com'
#
# # blacklist all plain HTTP URLs
# - scheme: 'http'
#
# # blacklist http(s)://www.acme.com/foo
# - netloc: 'www.acme.com'
# path: '/foo'
#
# # blacklist any URL with a literal IPv4 address
# - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'
# The largest allowed URL preview spidering size in bytes
max_spider_size: "10M"
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
# This Home Server's ReCAPTCHA public key.
recaptcha_public_key: "YOUR_PUBLIC_KEY"
# This Home Server's ReCAPTCHA private key.
recaptcha_private_key: "YOUR_PRIVATE_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
enable_registration_captcha: False
# A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
## Turn ##
# The public URIs of the TURN server to give to clients
turn_uris: []
# The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and
# does not use a token
#turn_username: "TURNSERVER_USERNAME"
#turn_password: "TURNSERVER_PASSWORD"
# How long generated TURN credentials last
turn_user_lifetime: "1h"
# Whether guests should be allowed to use the TURN server.
# This defaults to True, otherwise VoIP will be unreliable for guests.
# However, it does introduce a slight security risk as it allows users to
# connect to arbitrary endpoints without having first signed up for a
# valid account (e.g. by passing a CAPTCHA).
turn_allow_guests: False
## Registration ##
# Enable registration for new users.
enable_registration: False
# The user must provide all of the below types of 3PID when registering.
#
# registrations_require_3pid:
# - email
# - msisdn
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
# allowed_local_3pids:
# - medium: email
# pattern: ".*@matrix\.org"
# - medium: email
# pattern: ".*@vector\.im"
# - medium: msisdn
# pattern: "\+44"
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
# registration_shared_secret: <PRIVATE STRING>
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number is 12 (which equates to 2^12 rounds).
# N.B. that increasing this will exponentially increase the time required
# to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
# The list of identity servers trusted to verify third party
# identifiers by this server.
trusted_third_party_id_servers:
- matrix.org
- vector.im
- riot.im
# Users who register on this homeserver will automatically be joined
# to these rooms
#auto_join_rooms:
# - "#example:example.com"
## Metrics ###
# Enable collection and rendering of performance metrics
enable_metrics: False
## API Configuration ##
# A list of event types that will be included in the room_invite_state
room_invite_state_types:
- "m.room.join_rules"
- "m.room.canonical_alias"
- "m.room.avatar"
- "m.room.name"
# A list of application service config file to use
app_service_config_files: []
# macaroon_secret_key: <PRIVATE STRING>
# Used to enable access token expiration.
expire_access_token: False
## Signing Keys ##
# Path to the signing key to sign messages with
signing_key_path: "/etc/matrix-synapse/homeserver.signing.key"
# The keys that the server used to sign messages with but won't use
# to sign new messages. E.g. it has lost its private key
old_signing_keys: {}
# "ed25519:auto":
# # Base64 encoded public key
# key: "The public part of your old signing key."
# # Millisecond POSIX timestamp when the key expired.
# expired_ts: 123456789123
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
# Determines how quickly servers will query to check which keys
# are still valid.
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
perspectives:
servers:
"matrix.org":
verify_keys:
"ed25519:auto":
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
# Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file
# idp_redirect_url: Identity provider URL which will redirect
# the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config.
#saml2_config:
# enabled: true
# config_path: "/home/erikj/git/synapse/sp_conf.py"
# idp_redirect_url: "http://test/idp"
# Enable CAS for registration and login.
#cas_config:
# enabled: true
# server_url: "https://cas-server.com"
# service_url: "https://homeserver.domain.com:8448"
# #required_attributes:
# # name: value
# The JWT needs to contain a globally unique "sub" (subject) claim.
#
# jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
# Enable password for login.
password_config:
enabled: true
# Uncomment and change to a secret random string for extra security.
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#pepper: ""
# Enable sending emails for notification events
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
#
# If your SMTP server requires authentication, the optional smtp_user &
# smtp_pass variables should be used
#
#email:
# enable_notifs: false
# smtp_host: "localhost"
# smtp_port: 25
# smtp_user: "exampleusername"
# smtp_pass: "examplepassword"
# require_transport_security: False
# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
# app_name: Matrix
# template_dir: res/templates
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
# notif_for_new_users: True
# riot_base_url: "http://localhost/riot"
# password_providers:
# - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
# Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details
# like the sender, or just the event ID and room ID (`event_id_only`).
# If clients choose the former, this option controls whether the
# notification request includes the content of the event (other details
# like the sender are still included). For `event_id_only` push, it
# has no effect.
# For modern android devices the notification content will still appear
# because it is loaded by the app. iPhone, however will send a
# notification saying only that a message arrived and who it came from.
#
#push:
# include_content: true
# spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
# Whether to allow non server admins to create groups on this server
enable_group_creation: false
# If enabled, non server admins can only create groups with local parts
# starting with this prefix
# group_creation_prefix: "unofficial/"
# User Directory configuration
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to run
# UPDATE user_directory_stream_pos SET stream_id = NULL;
# on your database to tell it to rebuild the user_directory search indexes.
#
#user_directory:
# search_all_users: false

1
debian/install vendored
View File

@ -1,2 +1 @@
debian/homeserver.yaml etc/matrix-synapse
debian/log.yaml etc/matrix-synapse debian/log.yaml etc/matrix-synapse

View File

@ -10,12 +10,12 @@
# can be passed on the commandline for debugging. # can be passed on the commandline for debugging.
import argparse import argparse
from concurrent.futures import ThreadPoolExecutor
import os import os
import signal import signal
import subprocess import subprocess
import sys import sys
import threading import threading
from concurrent.futures import ThreadPoolExecutor
DISTS = ( DISTS = (
"debian:stretch", "debian:stretch",

View File

@ -819,7 +819,9 @@ class Auth(object):
elif threepid: elif threepid:
# If the user does not exist yet, but is signing up with a # If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check # reserved threepid then pass auth check
if is_threepid_reserved(self.hs.config, threepid): if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
return return
# Else if there is no room in the MAU bucket, bail # Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()

View File

@ -13,10 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc import gc
import logging import logging
import os import os
import sys import sys
import traceback
from six import iteritems from six import iteritems
@ -324,17 +326,12 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
@ -361,12 +358,53 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name']) logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup() hs.setup()
hs.start_listening()
@defer.inlineCallbacks
def start(): def start():
hs.get_pusherpool().start() try:
hs.get_datastore().start_profiling() # Check if the certificate is still valid.
hs.get_datastore().start_doing_background_updates() cert_days_remaining = hs.config.is_disk_cert_valid()
if hs.config.acme_enabled:
# If ACME is enabled, we might need to provision a certificate
# before starting.
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with.
yield acme.start_listening()
# We want to reprovision if cert_days_remaining is None (meaning no
# certificate exists), or the days remaining number it returns
# is less than our re-registration threshold.
if (cert_days_remaining is None) or (
not cert_days_remaining > hs.config.acme_reprovision_threshold
):
yield acme.provision_certificate()
# Read the certificate from disk and build the context factories for
# TLS.
hs.config.read_certificate_from_disk()
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
# It is now safe to start your Synapse.
hs.start_listening()
hs.get_pusherpool().start()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
except Exception as e:
# If a DeferredList failed (like in listening on the ACME listener),
# we need to print the subfailure explicitly.
if isinstance(e, defer.FirstError):
e.subFailure.printTraceback(sys.stderr)
sys.exit(1)
# Something else went wrong when starting. Print it and bail out.
traceback.print_exc(file=sys.stderr)
sys.exit(1)
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -367,7 +367,7 @@ class Config(object):
if not keys_directory: if not keys_directory:
keys_directory = os.path.dirname(config_files[-1]) keys_directory = os.path.dirname(config_files[-1])
config_dir_path = os.path.abspath(keys_directory) self.config_dir_path = os.path.abspath(keys_directory)
specified_config = {} specified_config = {}
for config_file in config_files: for config_file in config_files:
@ -379,7 +379,7 @@ class Config(object):
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
config_string = self.generate_config( config_string = self.generate_config(
config_dir_path=config_dir_path, config_dir_path=self.config_dir_path,
data_dir_path=os.getcwd(), data_dir_path=os.getcwd(),
server_name=server_name, server_name=server_name,
generate_secrets=False, generate_secrets=False,

View File

@ -256,7 +256,11 @@ class ServerConfig(Config):
# #
# web_client_location: "/path/to/web/root" # web_client_location: "/path/to/web/root"
# The public-facing base URL for the client API (not including _matrix/...) # The public-facing base URL that clients use to access this HS
# (not including _matrix/...). This is the same URL a user would
# enter into the 'custom HS URL' field on their client. If you
# use synapse with a reverse proxy, this should be the URL to reach
# synapse via the proxy.
# public_baseurl: https://example.com:8448/ # public_baseurl: https://example.com:8448/
# Set the soft limit on the number of file descriptors synapse can use # Set the soft limit on the number of file descriptors synapse can use
@ -420,19 +424,18 @@ class ServerConfig(Config):
" service on the given port.") " service on the given port.")
def is_threepid_reserved(config, threepid): def is_threepid_reserved(reserved_threepids, threepid):
"""Check the threepid against the reserved threepid config """Check the threepid against the reserved threepid config
Args: Args:
config(ServerConfig) - to access server config attributes reserved_threepids([dict]) - list of reserved threepids
threepid(dict) - The threepid to test for threepid(dict) - The threepid to test for
Returns: Returns:
boolean Is the threepid undertest reserved_user boolean Is the threepid undertest reserved_user
""" """
for tp in config.mau_limits_reserved_threepids: for tp in reserved_threepids:
if (threepid['medium'] == tp['medium'] if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
and threepid['address'] == tp['address']):
return True return True
return False return False

View File

@ -13,45 +13,38 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os import os
from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from OpenSSL import crypto from OpenSSL import crypto
from ._base import Config from synapse.config._base import Config
logger = logging.getLogger()
class TlsConfig(Config): class TlsConfig(Config):
def read_config(self, config): def read_config(self, config):
self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path")
)
self.tls_certificate_file = config.get("tls_certificate_path")
acme_config = config.get("acme", {})
self.acme_enabled = acme_config.get("enabled", False)
self.acme_url = acme_config.get(
"url", "https://acme-v01.api.letsencrypt.org/directory"
)
self.acme_port = acme_config.get("port", 8449)
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
self._original_tls_fingerprints = config["tls_fingerprints"]
self.tls_fingerprints = list(self._original_tls_fingerprints)
self.no_tls = config.get("no_tls", False) self.no_tls = config.get("no_tls", False)
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
config.get("tls_private_key_path")
)
self.tls_fingerprints = config["tls_fingerprints"]
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
# This config option applies to non-federation HTTP clients # This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such) # (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for # It should never be used in production, and is intended for
@ -60,13 +53,70 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use" "use_insecure_ssl_client_just_for_testing_do_not_use"
) )
self.tls_certificate = None
self.tls_private_key = None
def is_disk_cert_valid(self):
"""
Is the certificate we have on disk valid, and if so, for how long?
Returns:
int: Days remaining of certificate validity.
None: No certificate exists.
"""
if not os.path.exists(self.tls_certificate_file):
return None
try:
with open(self.tls_certificate_file, 'rb') as f:
cert_pem = f.read()
except Exception:
logger.exception("Failed to read existing certificate off disk!")
raise
try:
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
except Exception:
logger.exception("Failed to parse existing certificate off disk!")
raise
# YYYYMMDDhhmmssZ -- in UTC
expires_on = datetime.strptime(
tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
)
now = datetime.utcnow()
days_remaining = (expires_on - now).days
return days_remaining
def read_certificate_from_disk(self):
"""
Read the certificates from disk.
"""
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
if not self.no_tls:
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
self.tls_fingerprints = list(self._original_tls_fingerprints)
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt" tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key" tls_private_key_path = base_key_name + ".tls.key"
return """\ return (
"""\
# PEM encoded X509 certificate for TLS. # PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse # You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair # autogenerates on launch with your own SSL certificate + key pair
@ -107,7 +157,24 @@ class TlsConfig(Config):
# #
tls_fingerprints: [] tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals()
## Support for ACME certificate auto-provisioning.
# acme:
# enabled: false
## ACME path.
## If you only want to test, use the staging url:
## https://acme-staging.api.letsencrypt.org/directory
# url: 'https://acme-v01.api.letsencrypt.org/directory'
## Port number (to listen for the HTTP-01 challenge).
## Using port 80 requires utilising something like authbind, or proxying to it.
# port: 8449
## Hosts to bind to.
# bind_addresses: ['127.0.0.1']
## How many days remaining on a certificate before it is renewed.
# reprovision_threshold: 30
"""
% locals()
)
def read_tls_certificate(self, cert_path): def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate") cert_pem = self.read_file(cert_path, "tls_certificate")

View File

@ -17,6 +17,7 @@ from zope.interface import implementer
from OpenSSL import SSL, crypto from OpenSSL import SSL, crypto
from twisted.internet._sslverify import _defaultCurveName from twisted.internet._sslverify import _defaultCurveName
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -98,8 +99,14 @@ class ClientTLSOptions(object):
def __init__(self, hostname, ctx): def __init__(self, hostname, ctx):
self._ctx = ctx self._ctx = ctx
self._hostname = hostname
self._hostnameBytes = _idnaBytes(hostname) if isIPAddress(hostname) or isIPv6Address(hostname):
self._hostnameBytes = hostname.encode('ascii')
self._sendSNI = False
else:
self._hostnameBytes = _idnaBytes(hostname)
self._sendSNI = True
ctx.set_info_callback( ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback) _tolerateErrors(self._identityVerifyingInfoCallback)
) )
@ -111,7 +118,9 @@ class ClientTLSOptions(object):
return connection return connection
def _identityVerifyingInfoCallback(self, connection, where, ret): def _identityVerifyingInfoCallback(self, connection, where, ret):
if where & SSL.SSL_CB_HANDSHAKE_START: # Literal IPv4 and IPv6 addresses are not permitted
# as host names according to the RFCs
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
connection.set_tlsext_host_name(self._hostnameBytes) connection.set_tlsext_host_name(self._hostnameBytes)

View File

@ -42,8 +42,13 @@ class _EventInternalMetadata(object):
def is_outlier(self): def is_outlier(self):
return getattr(self, "outlier", False) return getattr(self, "outlier", False)
def is_invite_from_remote(self): def is_out_of_band_membership(self):
return getattr(self, "invite_from_remote", False) """Whether this is an out of band membership, like an invite or an invite
rejection. This is needed as those events are marked as outliers, but
they still need to be processed as if they're new events (e.g. updating
invite state in the database, relaying to clients, etc).
"""
return getattr(self, "out_of_band_membership", False)
def get_send_on_behalf_of(self): def get_send_on_behalf_of(self):
"""Whether this server should send the event on behalf of another server. """Whether this server should send the event on behalf of another server.

View File

@ -43,8 +43,8 @@ class FederationBase(object):
self._clock = hs.get_clock() self._clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
include_none=False): outlier=False, include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each """Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of the database and if not then request if from the originating server of
@ -56,8 +56,12 @@ class FederationBase(object):
a new list. a new list.
Args: Args:
origin (str)
pdu (list) pdu (list)
outlier (bool) room_version (str)
outlier (bool): Whether the events are outliers or not
include_none (str): Whether to include None in the returned list
for events that have failed their checks
Returns: Returns:
Deferred : A list of PDUs that have valid signatures and hashes. Deferred : A list of PDUs that have valid signatures and hashes.
@ -84,6 +88,7 @@ class FederationBase(object):
res = yield self.get_pdu( res = yield self.get_pdu(
destinations=[pdu.origin], destinations=[pdu.origin],
event_id=pdu.event_id, event_id=pdu.event_id,
room_version=room_version,
outlier=outlier, outlier=outlier,
timeout=10000, timeout=10000,
) )

View File

@ -37,7 +37,8 @@ from synapse.api.errors import (
HttpResponseException, HttpResponseException,
SynapseError, SynapseError,
) )
from synapse.events import builder, room_version_to_event_format from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -71,6 +72,8 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
self.event_builder_factory = hs.get_event_builder_factory()
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
clock=self._clock, clock=self._clock,
@ -207,7 +210,8 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_pdu(self, destinations, event_id, outlier=False, timeout=None): def get_pdu(self, destinations, event_id, room_version, outlier=False,
timeout=None):
"""Requests the PDU with given origin and ID from the remote home """Requests the PDU with given origin and ID from the remote home
servers. servers.
@ -217,6 +221,7 @@ class FederationClient(FederationBase):
Args: Args:
destinations (list): Which home servers to query destinations (list): Which home servers to query
event_id (str): event to fetch event_id (str): event to fetch
room_version (str): version of the room
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False` of the current block of PDUs. Defaults to `False`
@ -357,10 +362,13 @@ class FederationClient(FederationBase):
ev.event_id for ev in itertools.chain(pdus, auth_chain) ev.event_id for ev in itertools.chain(pdus, auth_chain)
]) ])
room_version = yield self.store.get_room_version(room_id)
signed_pdus = yield self._check_sigs_and_hash_and_fetch( signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, destination,
[p for p in pdus if p.event_id not in seen_events], [p for p in pdus if p.event_id not in seen_events],
outlier=True outlier=True,
room_version=room_version,
) )
signed_pdus.extend( signed_pdus.extend(
seen_events[p.event_id] for p in pdus if p.event_id in seen_events seen_events[p.event_id] for p in pdus if p.event_id in seen_events
@ -369,7 +377,8 @@ class FederationClient(FederationBase):
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, destination,
[p for p in auth_chain if p.event_id not in seen_events], [p for p in auth_chain if p.event_id not in seen_events],
outlier=True outlier=True,
room_version=room_version,
) )
signed_auth.extend( signed_auth.extend(
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
@ -416,6 +425,8 @@ class FederationClient(FederationBase):
random.shuffle(srvs) random.shuffle(srvs)
return srvs return srvs
room_version = yield self.store.get_room_version(room_id)
batch_size = 20 batch_size = 20
missing_events = list(missing_events) missing_events = list(missing_events)
for i in range(0, len(missing_events), batch_size): for i in range(0, len(missing_events), batch_size):
@ -426,6 +437,7 @@ class FederationClient(FederationBase):
self.get_pdu, self.get_pdu,
destinations=random_server_list(), destinations=random_server_list(),
event_id=e_id, event_id=e_id,
room_version=room_version,
) )
for e_id in batch for e_id in batch
] ]
@ -455,8 +467,11 @@ class FederationClient(FederationBase):
for p in res["auth_chain"] for p in res["auth_chain"]
] ]
room_version = yield self.store.get_room_version(room_id)
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True destination, auth_chain,
outlier=True, room_version=room_version,
) )
signed_auth.sort(key=lambda e: e.depth) signed_auth.sort(key=lambda e: e.depth)
@ -527,6 +542,8 @@ class FederationClient(FederationBase):
Does so by asking one of the already participating servers to create an Does so by asking one of the already participating servers to create an
event with proper context. event with proper context.
Returns a fully signed and hashed event.
Note that this does not append any events to any graphs. Note that this does not append any events to any graphs.
Args: Args:
@ -541,7 +558,7 @@ class FederationClient(FederationBase):
params (dict[str, str|Iterable[str]]): Query parameters to include in the params (dict[str, str|Iterable[str]]): Query parameters to include in the
request. request.
Return: Return:
Deferred[tuple[str, dict, int]]: resolves to a tuple of Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
`(origin, event, event_format)` where origin is the remote `(origin, event, event_format)` where origin is the remote
homeserver which generated the event, and event_format is one of homeserver which generated the event, and event_format is one of
`synapse.api.constants.EventFormatVersions`. `synapse.api.constants.EventFormatVersions`.
@ -583,7 +600,18 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict: if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = [] pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict) # Strip off the fields that we want to clobber.
pdu_dict.pop("origin", None)
pdu_dict.pop("origin_server_ts", None)
pdu_dict.pop("unsigned", None)
builder = self.event_builder_factory.new(pdu_dict)
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0]
)
ev = builder.build()
defer.returnValue( defer.returnValue(
(destination, ev, event_format) (destination, ev, event_format)
@ -662,9 +690,21 @@ class FederationClient(FederationBase):
for p in itertools.chain(state, auth_chain) for p in itertools.chain(state, auth_chain)
} }
room_version = None
for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
room_version = e.content.get("room_version", RoomVersions.V1)
break
if room_version is None:
# If the state doesn't have a create event then the room is
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
valid_pdus = yield self._check_sigs_and_hash_and_fetch( valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()), destination, list(pdus.values()),
outlier=True, outlier=True,
room_version=room_version,
) )
valid_pdus_map = { valid_pdus_map = {
@ -802,8 +842,10 @@ class FederationClient(FederationBase):
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
room_version = yield self.store.get_room_version(room_id)
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True destination, auth_chain, outlier=True, room_version=room_version,
) )
signed_auth.sort(key=lambda e: e.depth) signed_auth.sort(key=lambda e: e.depth)
@ -850,8 +892,10 @@ class FederationClient(FederationBase):
for e in content.get("events", []) for e in content.get("events", [])
] ]
room_version = yield self.store.get_room_version(room_id)
signed_events = yield self._check_sigs_and_hash_and_fetch( signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False destination, events, outlier=False, room_version=room_version,
) )
except HttpResponseException as e: except HttpResponseException as e:
if not e.code == 400: if not e.code == 400:

View File

@ -463,8 +463,10 @@ class FederationServer(FederationBase):
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
room_version = yield self.store.get_room_version(room_id)
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True origin, auth_chain, outlier=True, room_version=room_version,
) )
ret = yield self.handler.on_query_auth( ret = yield self.handler.on_query_auth(

147
synapse/handlers/acme.py Normal file
View File

@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import attr
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import serverFromString
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource
logger = logging.getLogger(__name__)
try:
from txacme.interfaces import ICertificateStore
@attr.s
@implementer(ICertificateStore)
class ErsatzStore(object):
"""
A store that only stores in memory.
"""
certs = attr.ib(default=attr.Factory(dict))
def store(self, server_name, pem_objects):
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
except ImportError:
# txacme is missing
pass
class AcmeHandler(object):
def __init__(self, hs):
self.hs = hs
self.reactor = hs.get_reactor()
@defer.inlineCallbacks
def start_listening(self):
# Configure logging for txacme, if you need to debug
# from eliot import add_destinations
# from eliot.twisted import TwistedDestination
#
# add_destinations(TwistedDestination())
from txacme.challenges import HTTP01Responder
from txacme.service import AcmeIssuingService
from txacme.endpoint import load_or_create_client_key
from txacme.client import Client
from josepy.jwa import RS256
self._store = ErsatzStore()
responder = HTTP01Responder()
self._issuer = AcmeIssuingService(
cert_store=self._store,
client_creator=(
lambda: Client.from_url(
reactor=self.reactor,
url=URL.from_text(self.hs.config.acme_url),
key=load_or_create_client_key(
FilePath(self.hs.config.config_dir_path)
),
alg=RS256,
)
),
clock=self.reactor,
responders=[responder],
)
well_known = Resource()
well_known.putChild(b'acme-challenge', responder.resource)
responder_resource = Resource()
responder_resource.putChild(b'.well-known', well_known)
responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
srv = server.Site(responder_resource)
listeners = []
for host in self.hs.config.acme_bind_addresses:
logger.info(
"Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
)
endpoint = serverFromString(
self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
)
listeners.append(endpoint.listen(srv))
# Make sure we are registered to the ACME server. There's no public API
# for this, it is usually triggered by startService, but since we don't
# want it to control where we save the certificates, we have to reach in
# and trigger the registration machinery ourselves.
self._issuer._registered = False
yield self._issuer._ensure_registered()
# Return a Deferred that will fire when all the servers have started up.
yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
@defer.inlineCallbacks
def provision_certificate(self):
logger.warning("Reprovisioning %s", self.hs.hostname)
try:
yield self._issuer.issue_cert(self.hs.hostname)
except Exception:
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
cert_chain = self._store.certs[self.hs.hostname]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
for x in cert_chain:
if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
private_key_file.write(x)
with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
for x in cert_chain:
if x.startswith(b"-----BEGIN CERTIFICATE-----"):
certificate_file.write(x)
except Exception:
logger.exception("Failed saving!")
raise
defer.returnValue(True)

View File

@ -34,6 +34,7 @@ from synapse.api.constants import (
EventTypes, EventTypes,
Membership, Membership,
RejectedReason, RejectedReason,
RoomVersions,
) )
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -43,10 +44,7 @@ from synapse.api.errors import (
StoreError, StoreError,
SynapseError, SynapseError,
) )
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import compute_event_signature
add_hashes_and_signatures,
compute_event_signature,
)
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.federation import ( from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet, ReplicationCleanRoomRestServlet,
@ -58,7 +56,6 @@ from synapse.types import UserID, get_domain_from_id
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server from synapse.visibility import filter_events_for_server
@ -342,6 +339,8 @@ class FederationHandler(BaseHandler):
room_id, event_id, p, room_id, event_id, p,
) )
room_version = yield self.store.get_room_version(room_id)
with logcontext.nested_logging_context(p): with logcontext.nested_logging_context(p):
# note that if any of the missing prevs share missing state or # note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped # auth events, the requests to fetch those events are deduped
@ -355,7 +354,7 @@ class FederationHandler(BaseHandler):
# we want the state *after* p; get_state_for_room returns the # we want the state *after* p; get_state_for_room returns the
# state *before* p. # state *before* p.
remote_event = yield self.federation_client.get_pdu( remote_event = yield self.federation_client.get_pdu(
[origin], p, outlier=True, [origin], p, room_version, outlier=True,
) )
if remote_event is None: if remote_event is None:
@ -379,7 +378,6 @@ class FederationHandler(BaseHandler):
for x in remote_state: for x in remote_state:
event_map[x.event_id] = x event_map[x.event_id] = x
room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_store( state_map = yield resolve_events_with_store(
room_version, state_maps, event_map, room_version, state_maps, event_map,
state_res_store=StateResolutionStore(self.store), state_res_store=StateResolutionStore(self.store),
@ -655,6 +653,8 @@ class FederationHandler(BaseHandler):
if dest == self.server_name: if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
room_version = yield self.store.get_room_version(room_id)
events = yield self.federation_client.backfill( events = yield self.federation_client.backfill(
dest, dest,
room_id, room_id,
@ -748,6 +748,7 @@ class FederationHandler(BaseHandler):
self.federation_client.get_pdu, self.federation_client.get_pdu,
[dest], [dest],
event_id, event_id,
room_version=room_version,
outlier=True, outlier=True,
timeout=10000, timeout=10000,
) )
@ -1083,7 +1084,6 @@ class FederationHandler(BaseHandler):
handled_events = set() handled_events = set()
try: try:
event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/ # Try the host we successfully got a response to /make_join/
# request first. # request first.
try: try:
@ -1287,7 +1287,7 @@ class FederationHandler(BaseHandler):
) )
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True event.internal_metadata.out_of_band_membership = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
@ -1313,7 +1313,7 @@ class FederationHandler(BaseHandler):
# Mark as outlier as we don't have any state for this event; we're not # Mark as outlier as we don't have any state for this event; we're not
# even in the room. # even in the room.
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event = self._sign_event(event) event.internal_metadata.out_of_band_membership = True
# Try the host that we succesfully called /make_leave/ on first for # Try the host that we succesfully called /make_leave/ on first for
# the /send_leave/ request. # the /send_leave/ request.
@ -1357,27 +1357,6 @@ class FederationHandler(BaseHandler):
assert(event.room_id == room_id) assert(event.room_id == room_id)
defer.returnValue((origin, event)) defer.returnValue((origin, event))
def _sign_event(self, event):
event.internal_metadata.outlier = False
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
return builder.build()
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_leave_request(self, room_id, user_id): def on_make_leave_request(self, room_id, user_id):
@ -1659,6 +1638,13 @@ class FederationHandler(BaseHandler):
create_event = e create_event = e
break break
if create_event is None:
# If the state doesn't have a create event then the room is
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
room_version = create_event.content.get("room_version", RoomVersions.V1)
missing_auth_events = set() missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]): for e in itertools.chain(auth_events, state, [event]):
for e_id in e.auth_event_ids(): for e_id in e.auth_event_ids():
@ -1669,6 +1655,7 @@ class FederationHandler(BaseHandler):
m_ev = yield self.federation_client.get_pdu( m_ev = yield self.federation_client.get_pdu(
[origin], [origin],
e_id, e_id,
room_version=room_version,
outlier=True, outlier=True,
timeout=10000, timeout=10000,
) )

View File

@ -73,8 +73,14 @@ class RoomListHandler(BaseHandler):
# We explicitly don't bother caching searches or requests for # We explicitly don't bother caching searches or requests for
# appservice specific lists. # appservice specific lists.
logger.info("Bypassing cache as search request.") logger.info("Bypassing cache as search request.")
# XXX: Quick hack to stop room directory queries taking too long.
# Timeout request after 60s. Probably want a more fundamental
# solution at some point
timeout = self.clock.time() + 60
return self._get_public_room_list( return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple, limit, since_token, search_filter,
network_tuple=network_tuple, timeout=timeout,
) )
key = (limit, since_token, network_tuple) key = (limit, since_token, network_tuple)
@ -87,7 +93,8 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None, def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None, search_filter=None,
network_tuple=EMPTY_THIRD_PARTY_ID,): network_tuple=EMPTY_THIRD_PARTY_ID,
timeout=None,):
if since_token and since_token != "END": if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token) since_token = RoomListNextBatch.from_token(since_token)
else: else:
@ -202,6 +209,9 @@ class RoomListHandler(BaseHandler):
chunk = [] chunk = []
for i in range(0, len(rooms_to_scan), step): for i in range(0, len(rooms_to_scan), step):
if timeout and self.clock.time() > timeout:
raise Exception("Timed out searching room directory")
batch = rooms_to_scan[i:i + step] batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch)) logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute( yield concurrently_execute(

View File

@ -19,6 +19,7 @@ from six import iteritems
from twisted.internet import defer from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
@ -163,6 +164,11 @@ class UserDirectoryHandler(object):
yield self._handle_deltas(deltas) yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"] self.pos = deltas[-1]["stream_id"]
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels(
"user_dir").set(self.pos)
yield self.store.update_user_directory_stream_pos(self.pos) yield self.store.update_user_directory_stream_pos(self.pos)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -333,9 +333,10 @@ class SimpleHttpClient(object):
"POST", uri, headers=Headers(actual_headers), data=query_bytes "POST", uri, headers=Headers(actual_headers), data=query_bytes
) )
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300: if 200 <= response.code < 300:
body = yield make_deferred_yieldable(treq.json_content(response)) defer.returnValue(json.loads(body))
defer.returnValue(body)
else: else:
raise HttpResponseException(response.code, response.phrase, body) raise HttpResponseException(response.code, response.phrase, body)

View File

@ -13,15 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import random
import re import re
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.error import ConnectError
from synapse.http.federation.srv_resolver import Server, resolve_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -88,140 +81,3 @@ def parse_and_validate_server_name(server_name):
)) ))
return host, port return host, port
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
Args:
reactor: Twisted reactor.
destination (unicode): The name of the server to connect to.
tls_client_options_factory
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
Factory which generates TLS options for client connections.
timeout (int): connection timeout in seconds
"""
domain, port = parse_server_name(destination)
endpoint_kw_args = {}
if timeout is not None:
endpoint_kw_args.update(timeout=timeout)
if tls_client_options_factory is None:
transport_endpoint = HostnameEndpoint
default_port = 8008
else:
# the SNI string should be the same as the Host header, minus the port.
# as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
# the Host header and SNI should therefore be the server_name of the remote
# server.
tls_options = tls_client_options_factory.get_options(domain)
def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS(
tls_options,
HostnameEndpoint(reactor, host, port, timeout=timeout),
)
default_port = 8448
if port is None:
return SRVClientEndpoint(
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
)
else:
return transport_endpoint(
reactor, domain, port, **endpoint_kw_args
)
class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect
picking the next server.
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=HostnameEndpoint,
endpoint_kw_args={}):
self.reactor = reactor
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None:
self.default_server = Server(
host=domain,
port=default_port,
)
else:
self.default_server = None
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
self.servers = None
self.used_servers = None
@defer.inlineCallbacks
def fetch_servers(self):
self.used_servers = []
self.servers = yield resolve_service(self.service_name)
def pick_server(self):
if not self.servers:
if self.used_servers:
self.servers = self.used_servers
self.used_servers = []
self.servers.sort()
elif self.default_server:
return self.default_server
else:
raise ConnectError(
"No server available for %s" % self.service_name
)
# look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
for index, server in enumerate(self.servers)
if server.priority == min_priority
)
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
# XXX: this looks totally dubious:
#
# (a) we never reuse a server until we have been through
# all of the servers at the same priority, so if the
# weights are A: 100, B:1, we always do ABABAB instead of
# AAAA...AAAB (approximately).
#
# (b) After using all the servers at the lowest priority,
# we move onto the next priority. We should only use the
# second priority if servers at the top priority are
# unreachable.
#
del self.servers[index]
self.used_servers.append(server)
return server
@defer.inlineCallbacks
def connect(self, protocolFactory):
if self.servers is None:
yield self.fetch_servers()
server = self.pick_server()
logger.info("Connecting to %s:%s", server.host, server.port)
endpoint = self.endpoint(
self.reactor, server.host, server.port, **self.endpoint_kw_args
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)

View File

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
@implementer(IAgent)
class MatrixFederationAgent(object):
"""An Agent-like thing which provides a `request` method which will look up a matrix
server and send an HTTP request to it.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
reactor (IReactor): twisted reactor to use for underlying requests
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
"""
def __init__(
self, reactor, tls_client_options_factory, _srv_resolver=None,
):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
if _srv_resolver is None:
_srv_resolver = SrvResolver()
self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
uri (bytes): Absolute URI to be retrieved
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
response status code). Fails if there is any problem which prevents that
response from being received (including problems that prevent the request
from being sent).
"""
parsed_uri = URI.fromBytes(uri)
server_name_bytes = parsed_uri.netloc
host, port = parse_server_name(server_name_bytes.decode("ascii"))
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
if self._tls_client_options_factory is None:
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(host)
if port is not None:
target = (host, port)
else:
service_name = b"_matrix._tcp.%s" % (server_name_bytes, )
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list:
target = (host, 8448)
logger.debug("No SRV record for %s, using %s", host, target)
else:
target = pick_server_from_list(server_list)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
logger.info("Connecting to %s:%s", target[0], target[1])
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
return ep
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield make_deferred_yieldable(
agent.request(method, uri, headers, bodyProducer)
)
defer.returnValue(res)

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import random
import time import time
import attr import attr
@ -51,74 +52,118 @@ class Server(object):
expires = attr.ib(default=0) expires = attr.ib(default=0)
@defer.inlineCallbacks def pick_server_from_list(server_list):
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): """Randomly choose a server from the server list
"""Look up a SRV record, with caching
Args:
server_list (list[Server]): list of candidate servers
Returns:
Tuple[bytes, int]: (host, port) pair for the chosen server
"""
if not server_list:
raise RuntimeError("pick_server_from_list called with empty list")
# TODO: currently we only use the lowest-priority servers. We should maintain a
# cache of servers known to be "down" and filter them out
min_priority = min(s.priority for s in server_list)
eligible_servers = list(s for s in server_list if s.priority == min_priority)
total_weight = sum(s.weight for s in eligible_servers)
target_weight = random.randint(0, total_weight)
for s in eligible_servers:
target_weight -= s.weight
if target_weight <= 0:
return s.host, s.port
# this should be impossible.
raise RuntimeError(
"pick_server_from_list got to end of eligible server list.",
)
class SrvResolver(object):
"""Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here. but the cache never gets populated), so we add our own caching layer here.
Args: Args:
service_name (unicode|bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object cache (dict): cache object
clock (object): clock implementation. must provide a time() method. get_time (callable): clock implementation. Should return seconds since the epoch
Returns:
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
""" """
# TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
# byteses; however they will obviously end up as separate entries in the cache. We self._dns_client = dns_client
# should pick one form and stick with it. self._cache = cache
cache_entry = cache.get(service_name, None) self._get_time = get_time
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
try: @defer.inlineCallbacks
answers, _, _ = yield make_deferred_yieldable( def resolve_service(self, service_name):
dns_client.lookupService(service_name), """Look up a SRV record
)
except DNSNameError: Args:
# TODO: cache this. We can get the SOA out of the exception, and use service_name (bytes): record to look up
# the negative-TTL value.
defer.returnValue([]) Returns:
except DomainError as e: Deferred[list[Server]]:
# We failed to resolve the name (other than a NameError) a list of the SRV records, or an empty list if none found
# Try something in the cache, else rereaise """
cache_entry = cache.get(service_name, None) now = int(self._get_time())
if not isinstance(service_name, bytes):
raise TypeError("%r is not a byte string" % (service_name,))
cache_entry = self._cache.get(service_name, None)
if cache_entry: if cache_entry:
logger.warn( if all(s.expires > now for s in cache_entry):
"Failed to resolve %r, falling back to cache. %r", servers = list(cache_entry)
service_name, e defer.returnValue(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
self._dns_client.lookupService(service_name),
) )
defer.returnValue(list(cache_entry)) except DNSNameError:
else: # TODO: cache this. We can get the SOA out of the exception, and use
raise e # the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1 if (len(answers) == 1
and answers[0].type == dns.SRV and answers[0].type == dns.SRV
and answers[0].payload and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')): and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name) raise ConnectError("Service %s unavailable" % service_name)
servers = [] servers = []
for answer in answers: for answer in answers:
if answer.type != dns.SRV or not answer.payload: if answer.type != dns.SRV or not answer.payload:
continue continue
payload = answer.payload payload = answer.payload
servers.append(Server( servers.append(Server(
host=payload.target.name, host=payload.target.name,
port=payload.port, port=payload.port,
priority=payload.priority, priority=payload.priority,
weight=payload.weight, weight=payload.weight,
expires=int(clock.time()) + answer.ttl, expires=now + answer.ttl,
)) ))
servers.sort() # FIXME: get rid of this (it's broken by the attrs change) self._cache[service_name] = list(servers)
cache[service_name] = list(servers) defer.returnValue(servers)
defer.returnValue(servers)

View File

@ -32,7 +32,7 @@ from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.task import _EPSILON, Cooperator from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool from twisted.web.client import FileBodyProducer
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
import synapse.metrics import synapse.metrics
@ -44,7 +44,7 @@ from synapse.api.errors import (
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
) )
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -66,20 +66,6 @@ else:
MAXINT = sys.maxint MAXINT = sys.maxint
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.reactor = hs.get_reactor()
self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri):
destination = uri.netloc.decode('ascii')
return matrix_federation_endpoint(
self.reactor, destination, timeout=10,
tls_client_options_factory=self.tls_client_options_factory
)
_next_id = 1 _next_id = 1
@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
self.signing_key = hs.config.signing_key[0] self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
reactor = hs.get_reactor() reactor = hs.get_reactor()
pool = HTTPConnectionPool(reactor)
pool.retryAutomatically = False self.agent = MatrixFederationAgent(
pool.maxPersistentPerHost = 5 hs.get_reactor(),
pool.cachedConnectionTimeout = 2 * 60 hs.tls_client_options_factory,
self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
) )
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._store = hs.get_datastore() self._store = hs.get_datastore()
@ -316,9 +300,9 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers headers_dict[b"Authorization"] = auth_headers
logger.info( logger.info(
"{%s} [%s] Sending request: %s %s", "{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.destination, request.method, request.txn_id, request.destination, request.method,
url_str, url_str, _sec_timeout,
) )
try: try:
@ -338,12 +322,11 @@ class MatrixFederationHttpClient(object):
reactor=self.hs.get_reactor(), reactor=self.hs.get_reactor(),
) )
response = yield make_deferred_yieldable( response = yield request_deferred
request_deferred,
)
except DNSLookupError as e: except DNSLookupError as e:
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
except Exception as e: except Exception as e:
logger.info("Failed to send request: %s", e)
raise_from(RequestSendFailed(e, can_retry=True), e) raise_from(RequestSendFailed(e, can_retry=True), e)
logger.info( logger.info(

View File

@ -79,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
# ConsentResource uses select_autoescape, which arrived in jinja 2.9 # ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"], "resources.consent": ["Jinja2>=2.9"],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": ["txacme>=0.9.2"],
"saml2": ["pysaml2>=4.5.0"], "saml2": ["pysaml2>=4.5.0"],
"url_preview": ["lxml>=3.5.0"], "url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0"], "test": ["mock>=2.0"],

View File

@ -416,8 +416,11 @@ class RegisterRestServlet(RestServlet):
) )
# Necessary due to auth checks prior to the threepid being # Necessary due to auth checks prior to the threepid being
# written to the db # written to the db
if is_threepid_reserved(self.hs.config, threepid): if threepid:
yield self.store.upsert_monthly_active_user(registered_user_id) if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
yield self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with # remember that we've now registered that user account, and with
# what user ID (since the user may not have specified) # what user ID (since the user may not have specified)

View File

@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler
@ -129,6 +130,7 @@ class HomeServer(object):
'sync_handler', 'sync_handler',
'typing_handler', 'typing_handler',
'room_list_handler', 'room_list_handler',
'acme_handler',
'auth_handler', 'auth_handler',
'device_handler', 'device_handler',
'e2e_keys_handler', 'e2e_keys_handler',
@ -310,6 +312,9 @@ class HomeServer(object):
def build_e2e_room_keys_handler(self): def build_e2e_room_keys_handler(self):
return E2eRoomKeysHandler(self) return E2eRoomKeysHandler(self)
def build_acme_handler(self):
return AcmeHandler(self)
def build_application_service_api(self): def build_application_service_api(self):
return ApplicationServiceApi(self) return ApplicationServiceApi(self)

View File

@ -26,6 +26,7 @@ from prometheus_client import Histogram
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
@ -192,6 +193,51 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = {"user_ips"}
if self.database_engine.can_native_upsert:
# Check ASAP (and then later, every 1s) to see if we have finished
# background updates of tables that aren't safe to update.
self._clock.call_later(
0.0,
run_as_background_process,
"upsert_safety_check",
self._check_safe_to_upsert
)
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?
If there are background updates, we will need to wait, as they may be
the addition of indexes that set the UNIQUE constraint that we require.
If the background updates have not completed, wait 15 sec and check again.
"""
updates = yield self._simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
)
updates = [x["update_name"] for x in updates]
# The User IPs table in schema #53 was missing a unique index, which we
# run as a background update.
if "user_ips_device_unique_index" not in updates:
self._unsafe_to_upsert_tables.discard("user_ips")
# If there's any tables left to check, reschedule to run.
if self._unsafe_to_upsert_tables:
self._clock.call_later(
15.0,
run_as_background_process,
"upsert_safety_check",
self._check_safe_to_upsert
)
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()
@ -494,8 +540,15 @@ class SQLBaseStore(object):
txn.executemany(sql, vals) txn.executemany(sql, vals)
@defer.inlineCallbacks @defer.inlineCallbacks
def _simple_upsert(self, table, keyvalues, values, def _simple_upsert(
insertion_values={}, desc="_simple_upsert", lock=True): self,
table,
keyvalues,
values,
insertion_values={},
desc="_simple_upsert",
lock=True
):
""" """
`lock` should generally be set to True (the default), but can be set `lock` should generally be set to True (the default), but can be set
@ -516,16 +569,21 @@ class SQLBaseStore(object):
inserting inserting
lock (bool): True to lock the table when doing the upsert. lock (bool): True to lock the table when doing the upsert.
Returns: Returns:
Deferred(bool): True if a new entry was created, False if an Deferred(None or bool): Native upserts always return None. Emulated
existing one was updated. upserts return True if a new entry was created, False if an existing
one was updated.
""" """
attempts = 0 attempts = 0
while True: while True:
try: try:
result = yield self.runInteraction( result = yield self.runInteraction(
desc, desc,
self._simple_upsert_txn, table, keyvalues, values, insertion_values, self._simple_upsert_txn,
lock=lock table,
keyvalues,
values,
insertion_values,
lock=lock,
) )
defer.returnValue(result) defer.returnValue(result)
except self.database_engine.module.IntegrityError as e: except self.database_engine.module.IntegrityError as e:
@ -537,12 +595,71 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry. # presumably we raced with another transaction: let's retry.
logger.warn( logger.warn(
"IntegrityError when upserting into %s; retrying: %s", "%s when upserting into %s; retrying: %s", e.__name__, table, e
table, e
) )
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, def _simple_upsert_txn(
lock=True): self,
txn,
table,
keyvalues,
values,
insertion_values={},
lock=True,
):
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
if (
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
return self._simple_upsert_txn_native_upsert(
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
)
else:
return self._simple_upsert_txn_emulated(
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
lock=lock,
)
def _simple_upsert_txn_emulated(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
bool: Return True if a new entry was created, False if an existing
one was updated.
"""
# We need to lock the table :(, unless we're *really* careful # We need to lock the table :(, unless we're *really* careful
if lock: if lock:
self.database_engine.lock_table(txn, table) self.database_engine.lock_table(txn, table)
@ -577,12 +694,44 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table, table,
", ".join(k for k in allvalues), ", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues) ", ".join("?" for _ in allvalues),
) )
txn.execute(sql, list(allvalues.values())) txn.execute(sql, list(allvalues.values()))
# successfully inserted # successfully inserted
return True return True
def _simple_upsert_txn_native_upsert(
self, txn, table, keyvalues, values, insertion_values={}
):
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
Returns:
None
"""
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
sql = (
"INSERT INTO %s (%s) VALUES (%s) "
"ON CONFLICT (%s) DO UPDATE SET %s"
) % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
", ".join(k + "=EXCLUDED." + k for k in values),
)
txn.execute(sql, list(allvalues.values()))
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"): allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to

View File

@ -143,6 +143,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# If it returns None, then we're processing the last batch # If it returns None, then we're processing the last batch
last = end_last_seen is None last = end_last_seen is None
logger.info(
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
begin_last_seen, end_last_seen,
)
def remove(txn): def remove(txn):
# This works by looking at all entries in the given time span, and # This works by looking at all entries in the given time span, and
# then for each (user_id, access_token, ip) tuple in that range # then for each (user_id, access_token, ip) tuple in that range
@ -170,7 +175,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
SELECT user_id, access_token, ip SELECT user_id, access_token, ip
FROM user_ips FROM user_ips
WHERE {} WHERE {}
ORDER BY last_seen
) c ) c
INNER JOIN user_ips USING (user_id, access_token, ip) INNER JOIN user_ips USING (user_id, access_token, ip)
GROUP BY user_id, access_token, ip GROUP BY user_id, access_token, ip
@ -253,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
) )
def _update_client_ips_batch_txn(self, txn, to_update): def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips") if "user_ips" in self._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
for entry in iteritems(to_update): for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry

View File

@ -18,7 +18,7 @@ import platform
from ._base import IncorrectDatabaseSetup from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = { SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine, "sqlite3": Sqlite3Engine,

View File

@ -38,6 +38,13 @@ class PostgresEngine(object):
return sql.replace("?", "%s") return sql.replace("?", "%s")
def on_new_connection(self, db_conn): def on_new_connection(self, db_conn):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
db_conn.set_isolation_level( db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
) )
@ -54,6 +61,13 @@ class PostgresEngine(object):
cursor.close() cursor.close()
@property
def can_native_upsert(self):
"""
Can we use native UPSERTs? This requires PostgreSQL 9.5+.
"""
return self._version >= 90500
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html # https://www.postgresql.org/docs/current/static/errcodes-appendix.html

View File

@ -15,6 +15,7 @@
import struct import struct
import threading import threading
from sqlite3 import sqlite_version_info
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
@ -30,6 +31,14 @@ class Sqlite3Engine(object):
self._current_state_group_id = None self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock() self._current_state_group_id_lock = threading.Lock()
@property
def can_native_upsert(self):
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
more work we haven't done yet to tell what was inserted vs updated.
"""
return sqlite_version_info >= (3, 24, 0)
def check_database(self, txn): def check_database(self, txn):
pass pass

View File

@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on # no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry # (app_id, pushkey, user_name) so _simple_upsert will retry
newly_inserted = yield self._simple_upsert( yield self._simple_upsert(
table="pushers", table="pushers",
keyvalues={ keyvalues={
"app_id": app_id, "app_id": app_id,
@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
lock=False, lock=False,
) )
if newly_inserted: user_has_pusher = self.get_if_user_has_pusher.cache.get(
(user_id,), None, update_metrics=False
)
if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before
yield self.runInteraction( yield self.runInteraction(
"add_pusher", "add_pusher",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream,

View File

@ -588,12 +588,12 @@ class RoomMemberStore(RoomMemberWorkerStore):
) )
# We update the local_invites table only if the event is "current", # We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. # i.e., its something that has just happened. If the event is an
# The only current event that can also be an outlier is if its an # outlier it is only current if its an "out of band membership",
# invite that has come in across federation. # like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and ( is_new_state = not backfilled and (
not event.internal_metadata.is_outlier() not event.internal_metadata.is_outlier()
or event.internal_metadata.is_invite_from_remote() or event.internal_metadata.is_out_of_band_membership()
) )
is_mine = self.hs.is_mine_id(event.state_key) is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine: if is_new_state and is_mine:

View File

@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally # We weight the localpart most highly, then display name and finally
# server name # server name
if new_entry: if self.database_engine.can_native_upsert:
sql = """ sql = """
INSERT INTO user_directory_search(user_id, vector) INSERT INTO user_directory_search(user_id, vector)
VALUES (?, VALUES (?,
setweight(to_tsvector('english', ?), 'A') setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D') || setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B') || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
) ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
""" """
txn.execute( txn.execute(
sql, sql,
@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore):
) )
) )
else: else:
sql = """ # TODO: Remove this code after we've bumped the minimum version
UPDATE user_directory_search # of postgres to always support upserts, so we can get rid of
SET vector = setweight(to_tsvector('english', ?), 'A') # `new_entry` usage
|| setweight(to_tsvector('english', ?), 'D') if new_entry is True:
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B') sql = """
WHERE user_id = ? INSERT INTO user_directory_search(user_id, vector)
""" VALUES (?,
txn.execute( setweight(to_tsvector('english', ?), 'A')
sql, || setweight(to_tsvector('english', ?), 'D')
( || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
get_localpart_from_id(user_id), get_domain_from_id(user_id), )
display_name, user_id, """
txn.execute(
sql,
(
user_id, get_localpart_from_id(user_id),
get_domain_from_id(user_id), display_name,
)
)
elif new_entry is False:
sql = """
UPDATE user_directory_search
SET vector = setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
sql,
(
get_localpart_from_id(user_id),
get_domain_from_id(user_id),
display_name, user_id,
)
)
else:
raise RuntimeError(
"upsert returned None when 'can_native_upsert' is False"
) )
)
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name,) if display_name else user_id value = "%s %s" % (user_id, display_name,) if display_name else user_id
self._simple_upsert_txn( self._simple_upsert_txn(

View File

@ -0,0 +1,315 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from mock import Mock
import treq
from twisted.internet import defer
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.test.ssl_helpers import ServerTLSContext
from twisted.web.http import HTTPChannel
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server
from synapse.util.logcontext import LoggingContext
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
logger = logging.getLogger(__name__)
class MatrixFederationAgentTests(TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock()
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(None),
_srv_resolver=self.mock_resolver,
)
def _make_connection(self, client_factory, expected_sni):
"""Builds a test server, and completes the outgoing client connection
Returns:
HTTPChannel: the test server
"""
# build the test server
server_tls_protocol = _build_test_server()
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
# a FakeTransport.
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
server_name = server_tls_protocol._tlsConnection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
# fish the test server back out of the server-side TLS protocol.
return server_tls_protocol.wrappedProtocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
"""
Sends a simple GET request via the agent, and checks its logcontext management
"""
with LoggingContext("one") as context:
fetch_d = self.agent.request(b'GET', uri)
# Nothing happened yet
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
_check_logcontext(LoggingContext.sentinel)
try:
fetch_res = yield fetch_d
defer.returnValue(fetch_res)
finally:
_check_logcontext(context)
def test_get(self):
"""
happy-path test of a GET request with an explicit port
"""
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b"testserv",
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
content = request.content.read()
self.assertEqual(content, b'')
# Deferred is still without a result
self.assertNoResult(test_d)
# send the headers
request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
request.write('')
self.reactor.pump((0.1,))
response = self.successResultOf(test_d)
# that should give us a Response object
self.assertEqual(response.code, 200)
# Send the body
request.write('{ "a": 1 }'.encode('ascii'))
request.finish()
self.reactor.pump((0.1,))
# check it can be read
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
def test_get_ip_address(self):
"""
Test the behaviour when the server name contains an explicit IP (with no port)
"""
# the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
self.mock_resolver.resolve_service.side_effect = lambda _: []
# then there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.1.2.3.4",
)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=None,
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_hostname_no_srv(self):
"""
Test the behaviour when the server name has no port, and no SRV record
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.testserv",
)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'testserv',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_hostname_srv(self):
"""
Test the behaviour when there is a single SRV record
"""
self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host="srvtarget", port=8443)
]
self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.testserv",
)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8443)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'testserv',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def _check_logcontext(context):
current = LoggingContext.current_context()
if current is not context:
raise AssertionError(
"Expected logcontext %s but was %s" % (context, current),
)
def _build_test_server():
"""Construct a test server
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
Returns:
TLSMemoryBIOProtocol
"""
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
server_tls_factory = TLSMemoryBIOFactory(
ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
)
return server_tls_factory.buildProtocol(None)
def _log_request(request):
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)

View File

@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.names import dns, error from twisted.names import dns, error
from synapse.http.federation.srv_resolver import resolve_service from synapse.http.federation.srv_resolver import SrvResolver
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from tests import unittest from tests import unittest
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = result_deferred dns_client_mock.lookupService.return_value = result_deferred
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_lookup(): def do_lookup():
with LoggingContext("one") as ctx: with LoggingContext("one") as ctx:
resolve_d = resolve_service( resolve_d = resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d) self.assertNoResult(resolve_d)
@ -83,16 +83,15 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock() dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.example.com" service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"]) entry = Mock(spec_set=["expires"])
entry.expires = 0 entry.expires = 0
cache = {service_name: [entry]} cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service( servers = yield resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name)
@ -106,17 +105,18 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock(spec_set=['lookupService']) dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[]) dns_client_mock.lookupService = Mock(spec_set=[])
service_name = "test_service.example.com" service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"]) entry = Mock(spec_set=["expires"])
entry.expires = 999999999 entry.expires = 999999999
cache = {service_name: [entry]} cache = {service_name: [entry]}
resolver = SrvResolver(
servers = yield resolve_service( dns_client=dns_client_mock, cache=cache, get_time=clock.time,
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
) )
servers = yield resolver.resolve_service(service_name)
self.assertFalse(dns_client_mock.lookupService.called) self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1) self.assertEquals(len(servers), 1)
@ -128,12 +128,13 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.example.com" service_name = b"test_service.example.com"
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError): with self.assertRaises(error.DNSServerError):
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache) yield resolver.resolve_service(service_name)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_name_error(self): def test_name_error(self):
@ -141,13 +142,12 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
service_name = "test_service.example.com" service_name = b"test_service.example.com"
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service( servers = yield resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertEquals(len(servers), 0) self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock() dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred dns_client_mock.lookupService.return_value = lookup_deferred
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service( resolve_d = resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d) self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConenctError # returning a single "." should make the lookup fail with a ConenctError
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock() dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred dns_client_mock.lookupService.return_value = lookup_deferred
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service( resolve_d = resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d) self.assertNoResult(resolve_d)
lookup_deferred.callback(( lookup_deferred.callback((

View File

@ -15,6 +15,7 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from twisted.internet.defer import TimeoutError from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError from twisted.internet.error import ConnectingCancelledError, DNSLookupError
from twisted.test.proto_helpers import StringTransport from twisted.test.proto_helpers import StringTransport
@ -26,11 +27,20 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient, MatrixFederationHttpClient,
MatrixFederationRequest, MatrixFederationRequest,
) )
from synapse.util.logcontext import LoggingContext
from tests.server import FakeTransport from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
def check_logcontext(context):
current = LoggingContext.current_context()
if current is not context:
raise AssertionError(
"Expected logcontext %s but was %s" % (context, current),
)
class FederationClientTests(HomeserverTestCase): class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
@ -43,6 +53,70 @@ class FederationClientTests(HomeserverTestCase):
self.cl = MatrixFederationHttpClient(self.hs) self.cl = MatrixFederationHttpClient(self.hs)
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
def test_client_get(self):
"""
happy-path test of a GET request
"""
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
# Nothing happened yet
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
check_logcontext(LoggingContext.sentinel)
try:
fetch_res = yield fetch_d
defer.returnValue(fetch_res)
finally:
check_logcontext(context)
test_d = do_request()
self.pump()
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8008)
# complete the connection and wire it up to a fake transport
protocol = factory.buildProtocol(None)
transport = StringTransport()
protocol.makeConnection(transport)
# that should have made it send the request to the transport
self.assertRegex(transport.value(), b"^GET /foo/bar")
# Deferred is still without a result
self.assertNoResult(test_d)
# Send it the HTTP response
res_json = '{ "a": 1 }'.encode('ascii')
protocol.dataReceived(
b"HTTP/1.1 200 OK\r\n"
b"Server: Fake\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: %i\r\n"
b"\r\n"
b"%s" % (len(res_json), res_json)
)
self.pump()
res = self.successResultOf(test_d)
# check the response is as expected
self.assertEqual(res, {"a": 1})
def test_dns_error(self): def test_dns_error(self):
""" """
If the DNS lookup returns an error, it will bubble up. If the DNS lookup returns an error, it will bubble up.
@ -54,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError) self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
self.pump()
# Nothing happened yet
self.assertNoResult(d)
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8008)
e = Exception("go away")
factory.clientConnectionFailed(None, e)
self.pump(0.5)
f = self.failureResultOf(d)
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIs(f.value.inner_exception, e)
def test_client_never_connect(self): def test_client_never_connect(self):
""" """
If the HTTP request is not connected and is timed out, it'll give a If the HTTP request is not connected and is timed out, it'll give a

View File

@ -1,4 +1,5 @@
import json import json
import logging
from io import BytesIO from io import BytesIO
from six import text_type from six import text_type
@ -22,6 +23,8 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth from tests.utils import setup_test_homeserver as _sth
logger = logging.getLogger(__name__)
class TimedOutException(Exception): class TimedOutException(Exception):
""" """
@ -339,7 +342,7 @@ def get_clock():
return (clock, hs_clock) return (clock, hs_clock)
@attr.s @attr.s(cmp=False)
class FakeTransport(object): class FakeTransport(object):
""" """
A twisted.internet.interfaces.ITransport implementation which sends all its data A twisted.internet.interfaces.ITransport implementation which sends all its data
@ -414,6 +417,11 @@ class FakeTransport(object):
self.buffer = self.buffer + byt self.buffer = self.buffer + byt
def _write(): def _write():
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
return
if getattr(self.other, "transport") is not None: if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer) self.other.dataReceived(self.buffer)
self.buffer = b"" self.buffer = b""
@ -421,7 +429,10 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _write) self._reactor.callLater(0.0, _write)
_write() # always actually do the write asynchronously. Some protocols (notably the
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
# still doing a write. Doing a callLater here breaks the cycle.
self._reactor.callLater(0.0, _write)
def writeSequence(self, seq): def writeSequence(self, seq):
for x in seq: for x in seq:

View File

@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection self.db_pool.runWithConnection = runWithConnection
config = Mock() config = Mock()
config._enable_native_upserts = False
config.event_cache_size = 1 config.event_cache_size = 1
config.database_config = {"name": "sqlite3"} config.database_config = {"name": "sqlite3"}
hs = TestHomeServer( hs = TestHomeServer(

View File

@ -19,7 +19,7 @@ from six import StringIO
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -30,12 +30,18 @@ from synapse.util import Clock
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest from tests import unittest
from tests.server import FakeTransport, make_request, render, setup_test_homeserver from tests.server import (
FakeTransport,
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class JsonResourceTests(unittest.TestCase): class JsonResourceTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.reactor = MemoryReactorClock() self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor) self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver( self.homeserver = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor

View File

@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName) method = getattr(self, methodName)
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
@around(self) @around(self)
def setUp(orig): def setUp(orig):
@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
""" """
kwargs = dict(kwargs) kwargs = dict(kwargs)
kwargs.update(self._hs_args) kwargs.update(self._hs_args)
return setup_test_homeserver(self.addCleanup, *args, **kwargs) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
# Run the database background updates.
if hasattr(stor, "do_next_background_update"):
while not self.get_success(stor.has_completed_background_updates()):
self.get_success(stor.do_next_background_update(1))
return hs
def pump(self, by=0.0): def pump(self, by=0.0):
""" """

View File

@ -154,7 +154,9 @@ def default_config(name):
config.update_user_directory = False config.update_user_directory = False
def is_threepid_reserved(threepid): def is_threepid_reserved(threepid):
return ServerConfig.is_threepid_reserved(config, threepid) return ServerConfig.is_threepid_reserved(
config.mau_limits_reserved_threepids, threepid
)
config.is_threepid_reserved.side_effect = is_threepid_reserved config.is_threepid_reserved.side_effect = is_threepid_reserved

View File

@ -149,4 +149,5 @@ deps =
codecov codecov
commands = commands =
coverage combine coverage combine
coverage xml
codecov -X gcov codecov -X gcov