mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-22 06:55:00 -05:00
Merge branch 'erikj/get_pdu_versions' into erikj/require_format_version
This commit is contained in:
commit
2a8edbaf74
1
changelog.d/4384.feature
Normal file
1
changelog.d/4384.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt).
|
1
changelog.d/4428.misc
Normal file
1
changelog.d/4428.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Move SRV logic into the Agent layer
|
1
changelog.d/4432.misc
Normal file
1
changelog.d/4432.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Apply a unique index to the user_ips table, preventing duplicates.
|
1
changelog.d/4433.misc
Normal file
1
changelog.d/4433.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
debian package: symlink to explicit python version
|
1
changelog.d/4445.feature
Normal file
1
changelog.d/4445.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add a metric for tracking event stream position of the user directory.
|
1
changelog.d/4448.misc
Normal file
1
changelog.d/4448.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add infrastructure to support different event formats
|
15
debian/build_virtualenv
vendored
15
debian/build_virtualenv
vendored
@ -6,7 +6,16 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
|
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
|
||||||
SNAKE=/usr/bin/python3
|
|
||||||
|
# make sure that the virtualenv links to the specific version of python, by
|
||||||
|
# dereferencing the python3 symlink.
|
||||||
|
#
|
||||||
|
# Otherwise, if somebody tries to install (say) the stretch package on buster,
|
||||||
|
# they will get a confusing error about "No module named 'synapse'", because
|
||||||
|
# python won't look in the right directory. At least this way, the error will
|
||||||
|
# be a *bit* more obvious.
|
||||||
|
#
|
||||||
|
SNAKE=`readlink -e /usr/bin/python3`
|
||||||
|
|
||||||
# try to set the CFLAGS so any compiled C extensions are compiled with the most
|
# try to set the CFLAGS so any compiled C extensions are compiled with the most
|
||||||
# generic as possible x64 instructions, so that compiling it on a new Intel chip
|
# generic as possible x64 instructions, so that compiling it on a new Intel chip
|
||||||
@ -46,3 +55,7 @@ cp -r tests "$tmpdir"
|
|||||||
PYTHONPATH="$tmpdir" \
|
PYTHONPATH="$tmpdir" \
|
||||||
debian/matrix-synapse-py3/opt/venvs/matrix-synapse/bin/python \
|
debian/matrix-synapse-py3/opt/venvs/matrix-synapse/bin/python \
|
||||||
-B -m twisted.trial --reporter=text -j2 tests
|
-B -m twisted.trial --reporter=text -j2 tests
|
||||||
|
|
||||||
|
# add a dependency on the right version of python to substvars.
|
||||||
|
PYPKG=`basename $SNAKE`
|
||||||
|
echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars
|
||||||
|
2
debian/control
vendored
2
debian/control
vendored
@ -27,8 +27,8 @@ Depends:
|
|||||||
adduser,
|
adduser,
|
||||||
debconf,
|
debconf,
|
||||||
python3-distutils|libpython3-stdlib (<< 3.6),
|
python3-distutils|libpython3-stdlib (<< 3.6),
|
||||||
python3,
|
|
||||||
${misc:Depends},
|
${misc:Depends},
|
||||||
|
${synapse:pydepends},
|
||||||
# some of our scripts use perl, but none of them are important,
|
# some of our scripts use perl, but none of them are important,
|
||||||
# so we put perl:Depends in Suggests rather than Depends.
|
# so we put perl:Depends in Suggests rather than Depends.
|
||||||
Suggests:
|
Suggests:
|
||||||
|
@ -10,12 +10,12 @@
|
|||||||
# can be passed on the commandline for debugging.
|
# can be passed on the commandline for debugging.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
DISTS = (
|
DISTS = (
|
||||||
"debian:stretch",
|
"debian:stretch",
|
||||||
|
@ -13,10 +13,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
@ -324,17 +326,12 @@ def setup(config_options):
|
|||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
|
||||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
|
||||||
|
|
||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
||||||
|
|
||||||
hs = SynapseHomeServer(
|
hs = SynapseHomeServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
|
||||||
tls_client_options_factory=tls_client_options_factory,
|
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
@ -361,12 +358,53 @@ def setup(config_options):
|
|||||||
logger.info("Database prepared in %s.", config.database_config['name'])
|
logger.info("Database prepared in %s.", config.database_config['name'])
|
||||||
|
|
||||||
hs.setup()
|
hs.setup()
|
||||||
hs.start_listening()
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def start():
|
def start():
|
||||||
|
try:
|
||||||
|
# Check if the certificate is still valid.
|
||||||
|
cert_days_remaining = hs.config.is_disk_cert_valid()
|
||||||
|
|
||||||
|
if hs.config.acme_enabled:
|
||||||
|
# If ACME is enabled, we might need to provision a certificate
|
||||||
|
# before starting.
|
||||||
|
acme = hs.get_acme_handler()
|
||||||
|
|
||||||
|
# Start up the webservices which we will respond to ACME
|
||||||
|
# challenges with.
|
||||||
|
yield acme.start_listening()
|
||||||
|
|
||||||
|
# We want to reprovision if cert_days_remaining is None (meaning no
|
||||||
|
# certificate exists), or the days remaining number it returns
|
||||||
|
# is less than our re-registration threshold.
|
||||||
|
if (cert_days_remaining is None) or (
|
||||||
|
not cert_days_remaining > hs.config.acme_reprovision_threshold
|
||||||
|
):
|
||||||
|
yield acme.provision_certificate()
|
||||||
|
|
||||||
|
# Read the certificate from disk and build the context factories for
|
||||||
|
# TLS.
|
||||||
|
hs.config.read_certificate_from_disk()
|
||||||
|
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
|
||||||
|
config
|
||||||
|
)
|
||||||
|
|
||||||
|
# It is now safe to start your Synapse.
|
||||||
|
hs.start_listening()
|
||||||
hs.get_pusherpool().start()
|
hs.get_pusherpool().start()
|
||||||
hs.get_datastore().start_profiling()
|
hs.get_datastore().start_profiling()
|
||||||
hs.get_datastore().start_doing_background_updates()
|
hs.get_datastore().start_doing_background_updates()
|
||||||
|
except Exception as e:
|
||||||
|
# If a DeferredList failed (like in listening on the ACME listener),
|
||||||
|
# we need to print the subfailure explicitly.
|
||||||
|
if isinstance(e, defer.FirstError):
|
||||||
|
e.subFailure.printTraceback(sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Something else went wrong when starting. Print it and bail out.
|
||||||
|
traceback.print_exc(file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
reactor.callWhenRunning(start)
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
@ -367,7 +367,7 @@ class Config(object):
|
|||||||
if not keys_directory:
|
if not keys_directory:
|
||||||
keys_directory = os.path.dirname(config_files[-1])
|
keys_directory = os.path.dirname(config_files[-1])
|
||||||
|
|
||||||
config_dir_path = os.path.abspath(keys_directory)
|
self.config_dir_path = os.path.abspath(keys_directory)
|
||||||
|
|
||||||
specified_config = {}
|
specified_config = {}
|
||||||
for config_file in config_files:
|
for config_file in config_files:
|
||||||
@ -379,7 +379,7 @@ class Config(object):
|
|||||||
|
|
||||||
server_name = specified_config["server_name"]
|
server_name = specified_config["server_name"]
|
||||||
config_string = self.generate_config(
|
config_string = self.generate_config(
|
||||||
config_dir_path=config_dir_path,
|
config_dir_path=self.config_dir_path,
|
||||||
data_dir_path=os.getcwd(),
|
data_dir_path=os.getcwd(),
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
generate_secrets=False,
|
generate_secrets=False,
|
||||||
|
@ -13,45 +13,38 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
|
||||||
from unpaddedbase64 import encode_base64
|
from unpaddedbase64 import encode_base64
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
from ._base import Config
|
from synapse.config._base import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
class TlsConfig(Config):
|
class TlsConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.tls_certificate = self.read_tls_certificate(
|
|
||||||
config.get("tls_certificate_path")
|
|
||||||
)
|
|
||||||
self.tls_certificate_file = config.get("tls_certificate_path")
|
|
||||||
|
|
||||||
|
acme_config = config.get("acme", {})
|
||||||
|
self.acme_enabled = acme_config.get("enabled", False)
|
||||||
|
self.acme_url = acme_config.get(
|
||||||
|
"url", "https://acme-v01.api.letsencrypt.org/directory"
|
||||||
|
)
|
||||||
|
self.acme_port = acme_config.get("port", 8449)
|
||||||
|
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
|
||||||
|
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
|
||||||
|
|
||||||
|
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
|
||||||
|
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
|
||||||
|
self._original_tls_fingerprints = config["tls_fingerprints"]
|
||||||
|
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||||
self.no_tls = config.get("no_tls", False)
|
self.no_tls = config.get("no_tls", False)
|
||||||
|
|
||||||
if self.no_tls:
|
|
||||||
self.tls_private_key = None
|
|
||||||
else:
|
|
||||||
self.tls_private_key = self.read_tls_private_key(
|
|
||||||
config.get("tls_private_key_path")
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tls_fingerprints = config["tls_fingerprints"]
|
|
||||||
|
|
||||||
# Check that our own certificate is included in the list of fingerprints
|
|
||||||
# and include it if it is not.
|
|
||||||
x509_certificate_bytes = crypto.dump_certificate(
|
|
||||||
crypto.FILETYPE_ASN1,
|
|
||||||
self.tls_certificate
|
|
||||||
)
|
|
||||||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
|
||||||
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
|
||||||
if sha256_fingerprint not in sha256_fingerprints:
|
|
||||||
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
|
|
||||||
|
|
||||||
# This config option applies to non-federation HTTP clients
|
# This config option applies to non-federation HTTP clients
|
||||||
# (e.g. for talking to recaptcha, identity servers, and such)
|
# (e.g. for talking to recaptcha, identity servers, and such)
|
||||||
# It should never be used in production, and is intended for
|
# It should never be used in production, and is intended for
|
||||||
@ -60,13 +53,70 @@ class TlsConfig(Config):
|
|||||||
"use_insecure_ssl_client_just_for_testing_do_not_use"
|
"use_insecure_ssl_client_just_for_testing_do_not_use"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.tls_certificate = None
|
||||||
|
self.tls_private_key = None
|
||||||
|
|
||||||
|
def is_disk_cert_valid(self):
|
||||||
|
"""
|
||||||
|
Is the certificate we have on disk valid, and if so, for how long?
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Days remaining of certificate validity.
|
||||||
|
None: No certificate exists.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(self.tls_certificate_file):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.tls_certificate_file, 'rb') as f:
|
||||||
|
cert_pem = f.read()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to read existing certificate off disk!")
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to parse existing certificate off disk!")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# YYYYMMDDhhmmssZ -- in UTC
|
||||||
|
expires_on = datetime.strptime(
|
||||||
|
tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
|
||||||
|
)
|
||||||
|
now = datetime.utcnow()
|
||||||
|
days_remaining = (expires_on - now).days
|
||||||
|
return days_remaining
|
||||||
|
|
||||||
|
def read_certificate_from_disk(self):
|
||||||
|
"""
|
||||||
|
Read the certificates from disk.
|
||||||
|
"""
|
||||||
|
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
|
||||||
|
|
||||||
|
if not self.no_tls:
|
||||||
|
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
|
||||||
|
|
||||||
|
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||||
|
|
||||||
|
# Check that our own certificate is included in the list of fingerprints
|
||||||
|
# and include it if it is not.
|
||||||
|
x509_certificate_bytes = crypto.dump_certificate(
|
||||||
|
crypto.FILETYPE_ASN1, self.tls_certificate
|
||||||
|
)
|
||||||
|
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
||||||
|
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
||||||
|
if sha256_fingerprint not in sha256_fingerprints:
|
||||||
|
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
base_key_name = os.path.join(config_dir_path, server_name)
|
base_key_name = os.path.join(config_dir_path, server_name)
|
||||||
|
|
||||||
tls_certificate_path = base_key_name + ".tls.crt"
|
tls_certificate_path = base_key_name + ".tls.crt"
|
||||||
tls_private_key_path = base_key_name + ".tls.key"
|
tls_private_key_path = base_key_name + ".tls.key"
|
||||||
|
|
||||||
return """\
|
return (
|
||||||
|
"""\
|
||||||
# PEM encoded X509 certificate for TLS.
|
# PEM encoded X509 certificate for TLS.
|
||||||
# You can replace the self-signed certificate that synapse
|
# You can replace the self-signed certificate that synapse
|
||||||
# autogenerates on launch with your own SSL certificate + key pair
|
# autogenerates on launch with your own SSL certificate + key pair
|
||||||
@ -107,7 +157,24 @@ class TlsConfig(Config):
|
|||||||
#
|
#
|
||||||
tls_fingerprints: []
|
tls_fingerprints: []
|
||||||
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
||||||
""" % locals()
|
|
||||||
|
## Support for ACME certificate auto-provisioning.
|
||||||
|
# acme:
|
||||||
|
# enabled: false
|
||||||
|
## ACME path.
|
||||||
|
## If you only want to test, use the staging url:
|
||||||
|
## https://acme-staging.api.letsencrypt.org/directory
|
||||||
|
# url: 'https://acme-v01.api.letsencrypt.org/directory'
|
||||||
|
## Port number (to listen for the HTTP-01 challenge).
|
||||||
|
## Using port 80 requires utilising something like authbind, or proxying to it.
|
||||||
|
# port: 8449
|
||||||
|
## Hosts to bind to.
|
||||||
|
# bind_addresses: ['127.0.0.1']
|
||||||
|
## How many days remaining on a certificate before it is renewed.
|
||||||
|
# reprovision_threshold: 30
|
||||||
|
"""
|
||||||
|
% locals()
|
||||||
|
)
|
||||||
|
|
||||||
def read_tls_certificate(self, cert_path):
|
def read_tls_certificate(self, cert_path):
|
||||||
cert_pem = self.read_file(cert_path, "tls_certificate")
|
cert_pem = self.read_file(cert_path, "tls_certificate")
|
||||||
|
@ -43,8 +43,8 @@ class FederationBase(object):
|
|||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
|
||||||
include_none=False):
|
outlier=False, include_none=False):
|
||||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||||
one. If a PDU fails its signature check then we check if we have it in
|
one. If a PDU fails its signature check then we check if we have it in
|
||||||
the database and if not then request if from the originating server of
|
the database and if not then request if from the originating server of
|
||||||
@ -56,8 +56,12 @@ class FederationBase(object):
|
|||||||
a new list.
|
a new list.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
origin (str)
|
||||||
pdu (list)
|
pdu (list)
|
||||||
outlier (bool)
|
room_version (str)
|
||||||
|
outlier (bool): Whether the events are outliers or not
|
||||||
|
include_none (str): Whether to include None in the returned list
|
||||||
|
for events that have failed their checks
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||||
@ -84,6 +88,7 @@ class FederationBase(object):
|
|||||||
res = yield self.get_pdu(
|
res = yield self.get_pdu(
|
||||||
destinations=[pdu.origin],
|
destinations=[pdu.origin],
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
|
room_version=room_version,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
)
|
)
|
||||||
|
@ -207,7 +207,8 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
|
def get_pdu(self, destinations, event_id, room_version, outlier=False,
|
||||||
|
timeout=None):
|
||||||
"""Requests the PDU with given origin and ID from the remote home
|
"""Requests the PDU with given origin and ID from the remote home
|
||||||
servers.
|
servers.
|
||||||
|
|
||||||
@ -217,6 +218,7 @@ class FederationClient(FederationBase):
|
|||||||
Args:
|
Args:
|
||||||
destinations (list): Which home servers to query
|
destinations (list): Which home servers to query
|
||||||
event_id (str): event to fetch
|
event_id (str): event to fetch
|
||||||
|
room_version (str): version of the room
|
||||||
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
||||||
it's from an arbitary point in the context as opposed to part
|
it's from an arbitary point in the context as opposed to part
|
||||||
of the current block of PDUs. Defaults to `False`
|
of the current block of PDUs. Defaults to `False`
|
||||||
@ -357,10 +359,13 @@ class FederationClient(FederationBase):
|
|||||||
ev.event_id for ev in itertools.chain(pdus, auth_chain)
|
ev.event_id for ev in itertools.chain(pdus, auth_chain)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination,
|
destination,
|
||||||
[p for p in pdus if p.event_id not in seen_events],
|
[p for p in pdus if p.event_id not in seen_events],
|
||||||
outlier=True
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
signed_pdus.extend(
|
signed_pdus.extend(
|
||||||
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
|
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
|
||||||
@ -369,7 +374,8 @@ class FederationClient(FederationBase):
|
|||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination,
|
destination,
|
||||||
[p for p in auth_chain if p.event_id not in seen_events],
|
[p for p in auth_chain if p.event_id not in seen_events],
|
||||||
outlier=True
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
signed_auth.extend(
|
signed_auth.extend(
|
||||||
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
|
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
|
||||||
@ -416,6 +422,8 @@ class FederationClient(FederationBase):
|
|||||||
random.shuffle(srvs)
|
random.shuffle(srvs)
|
||||||
return srvs
|
return srvs
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
batch_size = 20
|
batch_size = 20
|
||||||
missing_events = list(missing_events)
|
missing_events = list(missing_events)
|
||||||
for i in range(0, len(missing_events), batch_size):
|
for i in range(0, len(missing_events), batch_size):
|
||||||
@ -426,6 +434,7 @@ class FederationClient(FederationBase):
|
|||||||
self.get_pdu,
|
self.get_pdu,
|
||||||
destinations=random_server_list(),
|
destinations=random_server_list(),
|
||||||
event_id=e_id,
|
event_id=e_id,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
for e_id in batch
|
for e_id in batch
|
||||||
]
|
]
|
||||||
@ -455,8 +464,11 @@ class FederationClient(FederationBase):
|
|||||||
for p in res["auth_chain"]
|
for p in res["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, auth_chain, outlier=True
|
destination, auth_chain,
|
||||||
|
outlier=True, room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
signed_auth.sort(key=lambda e: e.depth)
|
signed_auth.sort(key=lambda e: e.depth)
|
||||||
@ -661,9 +673,20 @@ class FederationClient(FederationBase):
|
|||||||
for p in itertools.chain(state, auth_chain)
|
for p in itertools.chain(state, auth_chain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
room_version = None
|
||||||
|
for e in state:
|
||||||
|
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||||
|
room_version = e.content.get("room_version", RoomVersions.V1)
|
||||||
|
break
|
||||||
|
|
||||||
|
if room_version is None:
|
||||||
|
# We use this error has that is what
|
||||||
|
raise SynapseError(400, "No create event in state")
|
||||||
|
|
||||||
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, list(pdus.values()),
|
destination, list(pdus.values()),
|
||||||
outlier=True,
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_pdus_map = {
|
valid_pdus_map = {
|
||||||
@ -801,8 +824,10 @@ class FederationClient(FederationBase):
|
|||||||
for e in content["auth_chain"]
|
for e in content["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, auth_chain, outlier=True
|
destination, auth_chain, outlier=True, room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
signed_auth.sort(key=lambda e: e.depth)
|
signed_auth.sort(key=lambda e: e.depth)
|
||||||
@ -849,8 +874,10 @@ class FederationClient(FederationBase):
|
|||||||
for e in content.get("events", [])
|
for e in content.get("events", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_events = yield self._check_sigs_and_hash_and_fetch(
|
signed_events = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, events, outlier=False
|
destination, events, outlier=False, room_version=room_version,
|
||||||
)
|
)
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
if not e.code == 400:
|
if not e.code == 400:
|
||||||
|
@ -463,8 +463,10 @@ class FederationServer(FederationBase):
|
|||||||
for e in content["auth_chain"]
|
for e in content["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
origin, auth_chain, outlier=True
|
origin, auth_chain, outlier=True, room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = yield self.handler.on_query_auth(
|
ret = yield self.handler.on_query_auth(
|
||||||
|
147
synapse/handlers/acme.py
Normal file
147
synapse/handlers/acme.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.endpoints import serverFromString
|
||||||
|
from twisted.python.filepath import FilePath
|
||||||
|
from twisted.python.url import URL
|
||||||
|
from twisted.web import server, static
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from txacme.interfaces import ICertificateStore
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
@implementer(ICertificateStore)
|
||||||
|
class ErsatzStore(object):
|
||||||
|
"""
|
||||||
|
A store that only stores in memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
certs = attr.ib(default=attr.Factory(dict))
|
||||||
|
|
||||||
|
def store(self, server_name, pem_objects):
|
||||||
|
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
|
||||||
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# txacme is missing
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AcmeHandler(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.reactor = hs.get_reactor()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def start_listening(self):
|
||||||
|
|
||||||
|
# Configure logging for txacme, if you need to debug
|
||||||
|
# from eliot import add_destinations
|
||||||
|
# from eliot.twisted import TwistedDestination
|
||||||
|
#
|
||||||
|
# add_destinations(TwistedDestination())
|
||||||
|
|
||||||
|
from txacme.challenges import HTTP01Responder
|
||||||
|
from txacme.service import AcmeIssuingService
|
||||||
|
from txacme.endpoint import load_or_create_client_key
|
||||||
|
from txacme.client import Client
|
||||||
|
from josepy.jwa import RS256
|
||||||
|
|
||||||
|
self._store = ErsatzStore()
|
||||||
|
responder = HTTP01Responder()
|
||||||
|
|
||||||
|
self._issuer = AcmeIssuingService(
|
||||||
|
cert_store=self._store,
|
||||||
|
client_creator=(
|
||||||
|
lambda: Client.from_url(
|
||||||
|
reactor=self.reactor,
|
||||||
|
url=URL.from_text(self.hs.config.acme_url),
|
||||||
|
key=load_or_create_client_key(
|
||||||
|
FilePath(self.hs.config.config_dir_path)
|
||||||
|
),
|
||||||
|
alg=RS256,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
clock=self.reactor,
|
||||||
|
responders=[responder],
|
||||||
|
)
|
||||||
|
|
||||||
|
well_known = Resource()
|
||||||
|
well_known.putChild(b'acme-challenge', responder.resource)
|
||||||
|
responder_resource = Resource()
|
||||||
|
responder_resource.putChild(b'.well-known', well_known)
|
||||||
|
responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
|
||||||
|
|
||||||
|
srv = server.Site(responder_resource)
|
||||||
|
|
||||||
|
listeners = []
|
||||||
|
|
||||||
|
for host in self.hs.config.acme_bind_addresses:
|
||||||
|
logger.info(
|
||||||
|
"Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
|
||||||
|
)
|
||||||
|
endpoint = serverFromString(
|
||||||
|
self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
|
||||||
|
)
|
||||||
|
listeners.append(endpoint.listen(srv))
|
||||||
|
|
||||||
|
# Make sure we are registered to the ACME server. There's no public API
|
||||||
|
# for this, it is usually triggered by startService, but since we don't
|
||||||
|
# want it to control where we save the certificates, we have to reach in
|
||||||
|
# and trigger the registration machinery ourselves.
|
||||||
|
self._issuer._registered = False
|
||||||
|
yield self._issuer._ensure_registered()
|
||||||
|
|
||||||
|
# Return a Deferred that will fire when all the servers have started up.
|
||||||
|
yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def provision_certificate(self):
|
||||||
|
|
||||||
|
logger.warning("Reprovisioning %s", self.hs.hostname)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self._issuer.issue_cert(self.hs.hostname)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Fail!")
|
||||||
|
raise
|
||||||
|
logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
|
||||||
|
cert_chain = self._store.certs[self.hs.hostname]
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
|
||||||
|
for x in cert_chain:
|
||||||
|
if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
|
||||||
|
private_key_file.write(x)
|
||||||
|
|
||||||
|
with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
|
||||||
|
for x in cert_chain:
|
||||||
|
if x.startswith(b"-----BEGIN CERTIFICATE-----"):
|
||||||
|
certificate_file.write(x)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed saving!")
|
||||||
|
raise
|
||||||
|
|
||||||
|
defer.returnValue(True)
|
@ -34,6 +34,7 @@ from synapse.api.constants import (
|
|||||||
EventTypes,
|
EventTypes,
|
||||||
Membership,
|
Membership,
|
||||||
RejectedReason,
|
RejectedReason,
|
||||||
|
RoomVersions,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
@ -342,6 +343,8 @@ class FederationHandler(BaseHandler):
|
|||||||
room_id, event_id, p,
|
room_id, event_id, p,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
with logcontext.nested_logging_context(p):
|
with logcontext.nested_logging_context(p):
|
||||||
# note that if any of the missing prevs share missing state or
|
# note that if any of the missing prevs share missing state or
|
||||||
# auth events, the requests to fetch those events are deduped
|
# auth events, the requests to fetch those events are deduped
|
||||||
@ -355,7 +358,7 @@ class FederationHandler(BaseHandler):
|
|||||||
# we want the state *after* p; get_state_for_room returns the
|
# we want the state *after* p; get_state_for_room returns the
|
||||||
# state *before* p.
|
# state *before* p.
|
||||||
remote_event = yield self.federation_client.get_pdu(
|
remote_event = yield self.federation_client.get_pdu(
|
||||||
[origin], p, outlier=True,
|
[origin], p, room_version, outlier=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if remote_event is None:
|
if remote_event is None:
|
||||||
@ -379,7 +382,6 @@ class FederationHandler(BaseHandler):
|
|||||||
for x in remote_state:
|
for x in remote_state:
|
||||||
event_map[x.event_id] = x
|
event_map[x.event_id] = x
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
|
||||||
state_map = yield resolve_events_with_store(
|
state_map = yield resolve_events_with_store(
|
||||||
room_version, state_maps, event_map,
|
room_version, state_maps, event_map,
|
||||||
state_res_store=StateResolutionStore(self.store),
|
state_res_store=StateResolutionStore(self.store),
|
||||||
@ -655,6 +657,8 @@ class FederationHandler(BaseHandler):
|
|||||||
if dest == self.server_name:
|
if dest == self.server_name:
|
||||||
raise SynapseError(400, "Can't backfill from self.")
|
raise SynapseError(400, "Can't backfill from self.")
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
events = yield self.federation_client.backfill(
|
events = yield self.federation_client.backfill(
|
||||||
dest,
|
dest,
|
||||||
room_id,
|
room_id,
|
||||||
@ -748,6 +752,7 @@ class FederationHandler(BaseHandler):
|
|||||||
self.federation_client.get_pdu,
|
self.federation_client.get_pdu,
|
||||||
[dest],
|
[dest],
|
||||||
event_id,
|
event_id,
|
||||||
|
room_version=room_version,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
)
|
)
|
||||||
@ -1659,6 +1664,8 @@ class FederationHandler(BaseHandler):
|
|||||||
create_event = e
|
create_event = e
|
||||||
break
|
break
|
||||||
|
|
||||||
|
room_version = create_event.content.get("room_version", RoomVersions.V1)
|
||||||
|
|
||||||
missing_auth_events = set()
|
missing_auth_events = set()
|
||||||
for e in itertools.chain(auth_events, state, [event]):
|
for e in itertools.chain(auth_events, state, [event]):
|
||||||
for e_id in e.auth_event_ids():
|
for e_id in e.auth_event_ids():
|
||||||
@ -1669,6 +1676,7 @@ class FederationHandler(BaseHandler):
|
|||||||
m_ev = yield self.federation_client.get_pdu(
|
m_ev = yield self.federation_client.get_pdu(
|
||||||
[origin],
|
[origin],
|
||||||
e_id,
|
e_id,
|
||||||
|
room_version=room_version,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
)
|
)
|
||||||
|
@ -19,6 +19,7 @@ from six import iteritems
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.metrics
|
||||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.roommember import ProfileInfo
|
from synapse.storage.roommember import ProfileInfo
|
||||||
@ -163,6 +164,11 @@ class UserDirectoryHandler(object):
|
|||||||
yield self._handle_deltas(deltas)
|
yield self._handle_deltas(deltas)
|
||||||
|
|
||||||
self.pos = deltas[-1]["stream_id"]
|
self.pos = deltas[-1]["stream_id"]
|
||||||
|
|
||||||
|
# Expose current event processing position to prometheus
|
||||||
|
synapse.metrics.event_processing_positions.labels(
|
||||||
|
"user_dir").set(self.pos)
|
||||||
|
|
||||||
yield self.store.update_user_directory_stream_pos(self.pos)
|
yield self.store.update_user_directory_stream_pos(self.pos)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -13,15 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import random
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
|
||||||
from twisted.internet.error import ConnectError
|
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import Server, resolve_service
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -88,140 +81,3 @@ def parse_and_validate_server_name(server_name):
|
|||||||
))
|
))
|
||||||
|
|
||||||
return host, port
|
return host, port
|
||||||
|
|
||||||
|
|
||||||
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
|
|
||||||
timeout=None):
|
|
||||||
"""Construct an endpoint for the given matrix destination.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
reactor: Twisted reactor.
|
|
||||||
destination (unicode): The name of the server to connect to.
|
|
||||||
tls_client_options_factory
|
|
||||||
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
|
|
||||||
Factory which generates TLS options for client connections.
|
|
||||||
timeout (int): connection timeout in seconds
|
|
||||||
"""
|
|
||||||
|
|
||||||
domain, port = parse_server_name(destination)
|
|
||||||
|
|
||||||
endpoint_kw_args = {}
|
|
||||||
|
|
||||||
if timeout is not None:
|
|
||||||
endpoint_kw_args.update(timeout=timeout)
|
|
||||||
|
|
||||||
if tls_client_options_factory is None:
|
|
||||||
transport_endpoint = HostnameEndpoint
|
|
||||||
default_port = 8008
|
|
||||||
else:
|
|
||||||
# the SNI string should be the same as the Host header, minus the port.
|
|
||||||
# as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
|
|
||||||
# the Host header and SNI should therefore be the server_name of the remote
|
|
||||||
# server.
|
|
||||||
tls_options = tls_client_options_factory.get_options(domain)
|
|
||||||
|
|
||||||
def transport_endpoint(reactor, host, port, timeout):
|
|
||||||
return wrapClientTLS(
|
|
||||||
tls_options,
|
|
||||||
HostnameEndpoint(reactor, host, port, timeout=timeout),
|
|
||||||
)
|
|
||||||
default_port = 8448
|
|
||||||
|
|
||||||
if port is None:
|
|
||||||
return SRVClientEndpoint(
|
|
||||||
reactor, "matrix", domain, protocol="tcp",
|
|
||||||
default_port=default_port, endpoint=transport_endpoint,
|
|
||||||
endpoint_kw_args=endpoint_kw_args
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return transport_endpoint(
|
|
||||||
reactor, domain, port, **endpoint_kw_args
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SRVClientEndpoint(object):
|
|
||||||
"""An endpoint which looks up SRV records for a service.
|
|
||||||
Cycles through the list of servers starting with each call to connect
|
|
||||||
picking the next server.
|
|
||||||
Implements twisted.internet.interfaces.IStreamClientEndpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, reactor, service, domain, protocol="tcp",
|
|
||||||
default_port=None, endpoint=HostnameEndpoint,
|
|
||||||
endpoint_kw_args={}):
|
|
||||||
self.reactor = reactor
|
|
||||||
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
|
|
||||||
|
|
||||||
if default_port is not None:
|
|
||||||
self.default_server = Server(
|
|
||||||
host=domain,
|
|
||||||
port=default_port,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.default_server = None
|
|
||||||
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.endpoint_kw_args = endpoint_kw_args
|
|
||||||
|
|
||||||
self.servers = None
|
|
||||||
self.used_servers = None
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fetch_servers(self):
|
|
||||||
self.used_servers = []
|
|
||||||
self.servers = yield resolve_service(self.service_name)
|
|
||||||
|
|
||||||
def pick_server(self):
|
|
||||||
if not self.servers:
|
|
||||||
if self.used_servers:
|
|
||||||
self.servers = self.used_servers
|
|
||||||
self.used_servers = []
|
|
||||||
self.servers.sort()
|
|
||||||
elif self.default_server:
|
|
||||||
return self.default_server
|
|
||||||
else:
|
|
||||||
raise ConnectError(
|
|
||||||
"No server available for %s" % self.service_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# look for all servers with the same priority
|
|
||||||
min_priority = self.servers[0].priority
|
|
||||||
weight_indexes = list(
|
|
||||||
(index, server.weight + 1)
|
|
||||||
for index, server in enumerate(self.servers)
|
|
||||||
if server.priority == min_priority
|
|
||||||
)
|
|
||||||
|
|
||||||
total_weight = sum(weight for index, weight in weight_indexes)
|
|
||||||
target_weight = random.randint(0, total_weight)
|
|
||||||
for index, weight in weight_indexes:
|
|
||||||
target_weight -= weight
|
|
||||||
if target_weight <= 0:
|
|
||||||
server = self.servers[index]
|
|
||||||
# XXX: this looks totally dubious:
|
|
||||||
#
|
|
||||||
# (a) we never reuse a server until we have been through
|
|
||||||
# all of the servers at the same priority, so if the
|
|
||||||
# weights are A: 100, B:1, we always do ABABAB instead of
|
|
||||||
# AAAA...AAAB (approximately).
|
|
||||||
#
|
|
||||||
# (b) After using all the servers at the lowest priority,
|
|
||||||
# we move onto the next priority. We should only use the
|
|
||||||
# second priority if servers at the top priority are
|
|
||||||
# unreachable.
|
|
||||||
#
|
|
||||||
del self.servers[index]
|
|
||||||
self.used_servers.append(server)
|
|
||||||
return server
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def connect(self, protocolFactory):
|
|
||||||
if self.servers is None:
|
|
||||||
yield self.fetch_servers()
|
|
||||||
server = self.pick_server()
|
|
||||||
logger.info("Connecting to %s:%s", server.host, server.port)
|
|
||||||
endpoint = self.endpoint(
|
|
||||||
self.reactor, server.host, server.port, **self.endpoint_kw_args
|
|
||||||
)
|
|
||||||
connection = yield endpoint.connect(protocolFactory)
|
|
||||||
defer.returnValue(connection)
|
|
||||||
|
124
synapse/http/federation/matrix_federation_agent.py
Normal file
124
synapse/http/federation/matrix_federation_agent.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
|
from twisted.web.client import URI, Agent, HTTPConnectionPool
|
||||||
|
from twisted.web.iweb import IAgent
|
||||||
|
|
||||||
|
from synapse.http.endpoint import parse_server_name
|
||||||
|
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IAgent)
|
||||||
|
class MatrixFederationAgent(object):
|
||||||
|
"""An Agent-like thing which provides a `request` method which will look up a matrix
|
||||||
|
server and send an HTTP request to it.
|
||||||
|
|
||||||
|
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reactor (IReactor): twisted reactor to use for underlying requests
|
||||||
|
|
||||||
|
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||||
|
factory to use for fetching client tls options, or none to disable TLS.
|
||||||
|
|
||||||
|
srv_resolver (SrvResolver|None):
|
||||||
|
SRVResolver impl to use for looking up SRV records. None to use a default
|
||||||
|
implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, reactor, tls_client_options_factory, _srv_resolver=None,
|
||||||
|
):
|
||||||
|
self._reactor = reactor
|
||||||
|
self._tls_client_options_factory = tls_client_options_factory
|
||||||
|
if _srv_resolver is None:
|
||||||
|
_srv_resolver = SrvResolver()
|
||||||
|
self._srv_resolver = _srv_resolver
|
||||||
|
|
||||||
|
self._pool = HTTPConnectionPool(reactor)
|
||||||
|
self._pool.retryAutomatically = False
|
||||||
|
self._pool.maxPersistentPerHost = 5
|
||||||
|
self._pool.cachedConnectionTimeout = 2 * 60
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
method (bytes): HTTP method: GET/POST/etc
|
||||||
|
|
||||||
|
uri (bytes): Absolute URI to be retrieved
|
||||||
|
|
||||||
|
headers (twisted.web.http_headers.Headers|None):
|
||||||
|
HTTP headers to send with the request, or None to
|
||||||
|
send no extra headers.
|
||||||
|
|
||||||
|
bodyProducer (twisted.web.iweb.IBodyProducer|None):
|
||||||
|
An object which can generate bytes to make up the
|
||||||
|
body of this request (for example, the properly encoded contents of
|
||||||
|
a file for a file upload). Or None if the request is to have
|
||||||
|
no body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[twisted.web.iweb.IResponse]:
|
||||||
|
fires when the header of the response has been received (regardless of the
|
||||||
|
response status code). Fails if there is any problem which prevents that
|
||||||
|
response from being received (including problems that prevent the request
|
||||||
|
from being sent).
|
||||||
|
"""
|
||||||
|
|
||||||
|
parsed_uri = URI.fromBytes(uri)
|
||||||
|
server_name_bytes = parsed_uri.netloc
|
||||||
|
host, port = parse_server_name(server_name_bytes.decode("ascii"))
|
||||||
|
|
||||||
|
# XXX disabling TLS is really only supported here for the benefit of the
|
||||||
|
# unit tests. We should make the UTs cope with TLS rather than having to make
|
||||||
|
# the code support the unit tests.
|
||||||
|
if self._tls_client_options_factory is None:
|
||||||
|
tls_options = None
|
||||||
|
else:
|
||||||
|
tls_options = self._tls_client_options_factory.get_options(host)
|
||||||
|
|
||||||
|
if port is not None:
|
||||||
|
target = (host, port)
|
||||||
|
else:
|
||||||
|
server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
|
||||||
|
if not server_list:
|
||||||
|
target = (host, 8448)
|
||||||
|
logger.debug("No SRV record for %s, using %s", host, target)
|
||||||
|
else:
|
||||||
|
target = pick_server_from_list(server_list)
|
||||||
|
|
||||||
|
class EndpointFactory(object):
|
||||||
|
@staticmethod
|
||||||
|
def endpointForURI(_uri):
|
||||||
|
logger.info("Connecting to %s:%s", target[0], target[1])
|
||||||
|
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
|
||||||
|
if tls_options is not None:
|
||||||
|
ep = wrapClientTLS(tls_options, ep)
|
||||||
|
return ep
|
||||||
|
|
||||||
|
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
|
||||||
|
res = yield make_deferred_yieldable(
|
||||||
|
agent.request(method, uri, headers, bodyProducer)
|
||||||
|
)
|
||||||
|
defer.returnValue(res)
|
@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
@ -51,34 +52,79 @@ class Server(object):
|
|||||||
expires = attr.ib(default=0)
|
expires = attr.ib(default=0)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def pick_server_from_list(server_list):
|
||||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
"""Randomly choose a server from the server list
|
||||||
"""Look up a SRV record, with caching
|
|
||||||
|
Args:
|
||||||
|
server_list (list[Server]): list of candidate servers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bytes, int]: (host, port) pair for the chosen server
|
||||||
|
"""
|
||||||
|
if not server_list:
|
||||||
|
raise RuntimeError("pick_server_from_list called with empty list")
|
||||||
|
|
||||||
|
# TODO: currently we only use the lowest-priority servers. We should maintain a
|
||||||
|
# cache of servers known to be "down" and filter them out
|
||||||
|
|
||||||
|
min_priority = min(s.priority for s in server_list)
|
||||||
|
eligible_servers = list(s for s in server_list if s.priority == min_priority)
|
||||||
|
total_weight = sum(s.weight for s in eligible_servers)
|
||||||
|
target_weight = random.randint(0, total_weight)
|
||||||
|
|
||||||
|
for s in eligible_servers:
|
||||||
|
target_weight -= s.weight
|
||||||
|
|
||||||
|
if target_weight <= 0:
|
||||||
|
return s.host, s.port
|
||||||
|
|
||||||
|
# this should be impossible.
|
||||||
|
raise RuntimeError(
|
||||||
|
"pick_server_from_list got to end of eligible server list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SrvResolver(object):
|
||||||
|
"""Interface to the dns client to do SRV lookups, with result caching.
|
||||||
|
|
||||||
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
||||||
but the cache never gets populated), so we add our own caching layer here.
|
but the cache never gets populated), so we add our own caching layer here.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_name (unicode|bytes): record to look up
|
|
||||||
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
||||||
cache (dict): cache object
|
cache (dict): cache object
|
||||||
clock (object): clock implementation. must provide a time() method.
|
get_time (callable): clock implementation. Should return seconds since the epoch
|
||||||
|
"""
|
||||||
|
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
|
||||||
|
self._dns_client = dns_client
|
||||||
|
self._cache = cache
|
||||||
|
self._get_time = get_time
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def resolve_service(self, service_name):
|
||||||
|
"""Look up a SRV record
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service_name (bytes): record to look up
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
|
Deferred[list[Server]]:
|
||||||
|
a list of the SRV records, or an empty list if none found
|
||||||
"""
|
"""
|
||||||
# TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
|
now = int(self._get_time())
|
||||||
# byteses; however they will obviously end up as separate entries in the cache. We
|
|
||||||
# should pick one form and stick with it.
|
if not isinstance(service_name, bytes):
|
||||||
cache_entry = cache.get(service_name, None)
|
raise TypeError("%r is not a byte string" % (service_name,))
|
||||||
|
|
||||||
|
cache_entry = self._cache.get(service_name, None)
|
||||||
if cache_entry:
|
if cache_entry:
|
||||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
if all(s.expires > now for s in cache_entry):
|
||||||
servers = list(cache_entry)
|
servers = list(cache_entry)
|
||||||
defer.returnValue(servers)
|
defer.returnValue(servers)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
answers, _, _ = yield make_deferred_yieldable(
|
answers, _, _ = yield make_deferred_yieldable(
|
||||||
dns_client.lookupService(service_name),
|
self._dns_client.lookupService(service_name),
|
||||||
)
|
)
|
||||||
except DNSNameError:
|
except DNSNameError:
|
||||||
# TODO: cache this. We can get the SOA out of the exception, and use
|
# TODO: cache this. We can get the SOA out of the exception, and use
|
||||||
@ -87,7 +133,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
|||||||
except DomainError as e:
|
except DomainError as e:
|
||||||
# We failed to resolve the name (other than a NameError)
|
# We failed to resolve the name (other than a NameError)
|
||||||
# Try something in the cache, else rereaise
|
# Try something in the cache, else rereaise
|
||||||
cache_entry = cache.get(service_name, None)
|
cache_entry = self._cache.get(service_name, None)
|
||||||
if cache_entry:
|
if cache_entry:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Failed to resolve %r, falling back to cache. %r",
|
"Failed to resolve %r, falling back to cache. %r",
|
||||||
@ -116,9 +162,8 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
|||||||
port=payload.port,
|
port=payload.port,
|
||||||
priority=payload.priority,
|
priority=payload.priority,
|
||||||
weight=payload.weight,
|
weight=payload.weight,
|
||||||
expires=int(clock.time()) + answer.ttl,
|
expires=now + answer.ttl,
|
||||||
))
|
))
|
||||||
|
|
||||||
servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
|
self._cache[service_name] = list(servers)
|
||||||
cache[service_name] = list(servers)
|
|
||||||
defer.returnValue(servers)
|
defer.returnValue(servers)
|
||||||
|
@ -32,7 +32,7 @@ from twisted.internet import defer, protocol
|
|||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.internet.task import _EPSILON, Cooperator
|
from twisted.internet.task import _EPSILON, Cooperator
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
|
from twisted.web.client import FileBodyProducer
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
@ -44,7 +44,7 @@ from synapse.api.errors import (
|
|||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.http.endpoint import matrix_federation_endpoint
|
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import timeout_deferred
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
@ -66,20 +66,6 @@ else:
|
|||||||
MAXINT = sys.maxint
|
MAXINT = sys.maxint
|
||||||
|
|
||||||
|
|
||||||
class MatrixFederationEndpointFactory(object):
|
|
||||||
def __init__(self, hs):
|
|
||||||
self.reactor = hs.get_reactor()
|
|
||||||
self.tls_client_options_factory = hs.tls_client_options_factory
|
|
||||||
|
|
||||||
def endpointForURI(self, uri):
|
|
||||||
destination = uri.netloc.decode('ascii')
|
|
||||||
|
|
||||||
return matrix_federation_endpoint(
|
|
||||||
self.reactor, destination, timeout=10,
|
|
||||||
tls_client_options_factory=self.tls_client_options_factory
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_next_id = 1
|
_next_id = 1
|
||||||
|
|
||||||
|
|
||||||
@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
|
|||||||
self.signing_key = hs.config.signing_key[0]
|
self.signing_key = hs.config.signing_key[0]
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
reactor = hs.get_reactor()
|
reactor = hs.get_reactor()
|
||||||
pool = HTTPConnectionPool(reactor)
|
|
||||||
pool.retryAutomatically = False
|
self.agent = MatrixFederationAgent(
|
||||||
pool.maxPersistentPerHost = 5
|
hs.get_reactor(),
|
||||||
pool.cachedConnectionTimeout = 2 * 60
|
hs.tls_client_options_factory,
|
||||||
self.agent = Agent.usingEndpointFactory(
|
|
||||||
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
|
||||||
)
|
)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
@ -316,9 +300,9 @@ class MatrixFederationHttpClient(object):
|
|||||||
headers_dict[b"Authorization"] = auth_headers
|
headers_dict[b"Authorization"] = auth_headers
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"{%s} [%s] Sending request: %s %s",
|
"{%s} [%s] Sending request: %s %s; timeout %fs",
|
||||||
request.txn_id, request.destination, request.method,
|
request.txn_id, request.destination, request.method,
|
||||||
url_str,
|
url_str, _sec_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -338,12 +322,11 @@ class MatrixFederationHttpClient(object):
|
|||||||
reactor=self.hs.get_reactor(),
|
reactor=self.hs.get_reactor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
response = yield make_deferred_yieldable(
|
response = yield request_deferred
|
||||||
request_deferred,
|
|
||||||
)
|
|
||||||
except DNSLookupError as e:
|
except DNSLookupError as e:
|
||||||
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
|
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.info("Failed to send request: %s", e)
|
||||||
raise_from(RequestSendFailed(e, can_retry=True), e)
|
raise_from(RequestSendFailed(e, can_retry=True), e)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -79,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
|
|||||||
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
|
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
|
||||||
"resources.consent": ["Jinja2>=2.9"],
|
"resources.consent": ["Jinja2>=2.9"],
|
||||||
|
|
||||||
|
# ACME support is required to provision TLS certificates from authorities
|
||||||
|
# that use the protocol, such as Let's Encrypt.
|
||||||
|
"acme": ["txacme>=0.9.2"],
|
||||||
|
|
||||||
"saml2": ["pysaml2>=4.5.0"],
|
"saml2": ["pysaml2>=4.5.0"],
|
||||||
"url_preview": ["lxml>=3.5.0"],
|
"url_preview": ["lxml>=3.5.0"],
|
||||||
"test": ["mock>=2.0"],
|
"test": ["mock>=2.0"],
|
||||||
|
@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
|
|||||||
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
|
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
|
||||||
from synapse.groups.groups_server import GroupsServerHandler
|
from synapse.groups.groups_server import GroupsServerHandler
|
||||||
from synapse.handlers import Handlers
|
from synapse.handlers import Handlers
|
||||||
|
from synapse.handlers.acme import AcmeHandler
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||||
@ -129,6 +130,7 @@ class HomeServer(object):
|
|||||||
'sync_handler',
|
'sync_handler',
|
||||||
'typing_handler',
|
'typing_handler',
|
||||||
'room_list_handler',
|
'room_list_handler',
|
||||||
|
'acme_handler',
|
||||||
'auth_handler',
|
'auth_handler',
|
||||||
'device_handler',
|
'device_handler',
|
||||||
'e2e_keys_handler',
|
'e2e_keys_handler',
|
||||||
@ -310,6 +312,9 @@ class HomeServer(object):
|
|||||||
def build_e2e_room_keys_handler(self):
|
def build_e2e_room_keys_handler(self):
|
||||||
return E2eRoomKeysHandler(self)
|
return E2eRoomKeysHandler(self)
|
||||||
|
|
||||||
|
def build_acme_handler(self):
|
||||||
|
return AcmeHandler(self)
|
||||||
|
|
||||||
def build_application_service_api(self):
|
def build_application_service_api(self):
|
||||||
return ApplicationServiceApi(self)
|
return ApplicationServiceApi(self)
|
||||||
|
|
||||||
|
@ -143,6 +143,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
# If it returns None, then we're processing the last batch
|
# If it returns None, then we're processing the last batch
|
||||||
last = end_last_seen is None
|
last = end_last_seen is None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
|
||||||
|
begin_last_seen, end_last_seen,
|
||||||
|
)
|
||||||
|
|
||||||
def remove(txn):
|
def remove(txn):
|
||||||
# This works by looking at all entries in the given time span, and
|
# This works by looking at all entries in the given time span, and
|
||||||
# then for each (user_id, access_token, ip) tuple in that range
|
# then for each (user_id, access_token, ip) tuple in that range
|
||||||
@ -170,7 +175,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||||||
SELECT user_id, access_token, ip
|
SELECT user_id, access_token, ip
|
||||||
FROM user_ips
|
FROM user_ips
|
||||||
WHERE {}
|
WHERE {}
|
||||||
ORDER BY last_seen
|
|
||||||
) c
|
) c
|
||||||
INNER JOIN user_ips USING (user_id, access_token, ip)
|
INNER JOIN user_ips USING (user_id, access_token, ip)
|
||||||
GROUP BY user_id, access_token, ip
|
GROUP BY user_id, access_token, ip
|
||||||
|
183
tests/http/federation/test_matrix_federation_agent.py
Normal file
183
tests/http/federation/test_matrix_federation_agent.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
import treq
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.protocol import Factory
|
||||||
|
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||||
|
from twisted.test.ssl_helpers import ServerTLSContext
|
||||||
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
|
from synapse.crypto.context_factory import ClientTLSOptionsFactory
|
||||||
|
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixFederationAgentTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
|
||||||
|
self.mock_resolver = Mock()
|
||||||
|
|
||||||
|
self.agent = MatrixFederationAgent(
|
||||||
|
reactor=self.reactor,
|
||||||
|
tls_client_options_factory=ClientTLSOptionsFactory(None),
|
||||||
|
_srv_resolver=self.mock_resolver,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_connection(self, client_factory):
|
||||||
|
"""Builds a test server, and completes the outgoing client connection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HTTPChannel: the test server
|
||||||
|
"""
|
||||||
|
|
||||||
|
# build the test server
|
||||||
|
server_tls_protocol = _build_test_server()
|
||||||
|
|
||||||
|
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||||
|
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||||
|
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
||||||
|
# a FakeTransport.
|
||||||
|
#
|
||||||
|
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||||
|
# stubbing that out here.
|
||||||
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
|
||||||
|
|
||||||
|
# tell the server tls protocol to send its stuff back to the client, too
|
||||||
|
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
|
||||||
|
|
||||||
|
# finally, give the reactor a pump to get the TLS juices flowing.
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
|
return server_tls_protocol.wrappedProtocol
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _make_get_request(self, uri):
|
||||||
|
"""
|
||||||
|
Sends a simple GET request via the agent, and checks its logcontext management
|
||||||
|
"""
|
||||||
|
with LoggingContext("one") as context:
|
||||||
|
fetch_d = self.agent.request(b'GET', uri)
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(fetch_d)
|
||||||
|
|
||||||
|
# should have reset logcontext to the sentinel
|
||||||
|
_check_logcontext(LoggingContext.sentinel)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fetch_res = yield fetch_d
|
||||||
|
defer.returnValue(fetch_res)
|
||||||
|
finally:
|
||||||
|
_check_logcontext(context)
|
||||||
|
|
||||||
|
def test_get(self):
|
||||||
|
"""
|
||||||
|
happy-path test of a GET request
|
||||||
|
"""
|
||||||
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(client_factory)
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b'GET')
|
||||||
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
|
self.assertEqual(
|
||||||
|
request.requestHeaders.getRawHeaders(b'host'),
|
||||||
|
[b'testserv:8448']
|
||||||
|
)
|
||||||
|
content = request.content.read()
|
||||||
|
self.assertEqual(content, b'')
|
||||||
|
|
||||||
|
# Deferred is still without a result
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# send the headers
|
||||||
|
request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
|
||||||
|
request.write('')
|
||||||
|
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
response = self.successResultOf(test_d)
|
||||||
|
|
||||||
|
# that should give us a Response object
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
|
# Send the body
|
||||||
|
request.write('{ "a": 1 }'.encode('ascii'))
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
# check it can be read
|
||||||
|
json = self.successResultOf(treq.json_content(response))
|
||||||
|
self.assertEqual(json, {"a": 1})
|
||||||
|
|
||||||
|
|
||||||
|
def _check_logcontext(context):
|
||||||
|
current = LoggingContext.current_context()
|
||||||
|
if current is not context:
|
||||||
|
raise AssertionError(
|
||||||
|
"Expected logcontext %s but was %s" % (context, current),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_test_server():
|
||||||
|
"""Construct a test server
|
||||||
|
|
||||||
|
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TLSMemoryBIOProtocol
|
||||||
|
"""
|
||||||
|
server_factory = Factory.forProtocol(HTTPChannel)
|
||||||
|
# Request.finish expects the factory to have a 'log' method.
|
||||||
|
server_factory.log = _log_request
|
||||||
|
|
||||||
|
server_tls_factory = TLSMemoryBIOFactory(
|
||||||
|
ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
return server_tls_factory.buildProtocol(None)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_request(request):
|
||||||
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
|
logger.info("Completed request %s", request)
|
@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
|
|||||||
from twisted.internet.error import ConnectError
|
from twisted.internet.error import ConnectError
|
||||||
from twisted.names import dns, error
|
from twisted.names import dns, error
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import resolve_service
|
from synapse.http.federation.srv_resolver import SrvResolver
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
dns_client_mock.lookupService.return_value = result_deferred
|
dns_client_mock.lookupService.return_value = result_deferred
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_lookup():
|
def do_lookup():
|
||||||
|
|
||||||
with LoggingContext("one") as ctx:
|
with LoggingContext("one") as ctx:
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
@ -83,16 +83,15 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
entry = Mock(spec_set=["expires"])
|
entry = Mock(spec_set=["expires"])
|
||||||
entry.expires = 0
|
entry.expires = 0
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||||
|
|
||||||
@ -106,17 +105,18 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
dns_client_mock = Mock(spec_set=['lookupService'])
|
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||||
dns_client_mock.lookupService = Mock(spec_set=[])
|
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
entry = Mock(spec_set=["expires"])
|
entry = Mock(spec_set=["expires"])
|
||||||
entry.expires = 999999999
|
entry.expires = 999999999
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
|
resolver = SrvResolver(
|
||||||
servers = yield resolve_service(
|
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
|
||||||
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
servers = yield resolver.resolve_service(service_name)
|
||||||
|
|
||||||
self.assertFalse(dns_client_mock.lookupService.called)
|
self.assertFalse(dns_client_mock.lookupService.called)
|
||||||
|
|
||||||
self.assertEquals(len(servers), 1)
|
self.assertEquals(len(servers), 1)
|
||||||
@ -128,12 +128,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
with self.assertRaises(error.DNSServerError):
|
with self.assertRaises(error.DNSServerError):
|
||||||
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
|
yield resolver.resolve_service(service_name)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_name_error(self):
|
def test_name_error(self):
|
||||||
@ -141,13 +142,12 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEquals(len(servers), 0)
|
self.assertEquals(len(servers), 0)
|
||||||
self.assertEquals(len(cache), 0)
|
self.assertEquals(len(cache), 0)
|
||||||
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
# returning a single "." should make the lookup fail with a ConenctError
|
# returning a single "." should make the lookup fail with a ConenctError
|
||||||
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
lookup_deferred.callback((
|
lookup_deferred.callback((
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import TimeoutError
|
from twisted.internet.defer import TimeoutError
|
||||||
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
||||||
from twisted.test.proto_helpers import StringTransport
|
from twisted.test.proto_helpers import StringTransport
|
||||||
@ -26,11 +27,20 @@ from synapse.http.matrixfederationclient import (
|
|||||||
MatrixFederationHttpClient,
|
MatrixFederationHttpClient,
|
||||||
MatrixFederationRequest,
|
MatrixFederationRequest,
|
||||||
)
|
)
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
from tests.server import FakeTransport
|
from tests.server import FakeTransport
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
def check_logcontext(context):
|
||||||
|
current = LoggingContext.current_context()
|
||||||
|
if current is not context:
|
||||||
|
raise AssertionError(
|
||||||
|
"Expected logcontext %s but was %s" % (context, current),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FederationClientTests(HomeserverTestCase):
|
class FederationClientTests(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
@ -43,6 +53,70 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.cl = MatrixFederationHttpClient(self.hs)
|
self.cl = MatrixFederationHttpClient(self.hs)
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
|
def test_client_get(self):
|
||||||
|
"""
|
||||||
|
happy-path test of a GET request
|
||||||
|
"""
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_request():
|
||||||
|
with LoggingContext("one") as context:
|
||||||
|
fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(fetch_d)
|
||||||
|
|
||||||
|
# should have reset logcontext to the sentinel
|
||||||
|
check_logcontext(LoggingContext.sentinel)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fetch_res = yield fetch_d
|
||||||
|
defer.returnValue(fetch_res)
|
||||||
|
finally:
|
||||||
|
check_logcontext(context)
|
||||||
|
|
||||||
|
test_d = do_request()
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8008)
|
||||||
|
|
||||||
|
# complete the connection and wire it up to a fake transport
|
||||||
|
protocol = factory.buildProtocol(None)
|
||||||
|
transport = StringTransport()
|
||||||
|
protocol.makeConnection(transport)
|
||||||
|
|
||||||
|
# that should have made it send the request to the transport
|
||||||
|
self.assertRegex(transport.value(), b"^GET /foo/bar")
|
||||||
|
|
||||||
|
# Deferred is still without a result
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# Send it the HTTP response
|
||||||
|
res_json = '{ "a": 1 }'.encode('ascii')
|
||||||
|
protocol.dataReceived(
|
||||||
|
b"HTTP/1.1 200 OK\r\n"
|
||||||
|
b"Server: Fake\r\n"
|
||||||
|
b"Content-Type: application/json\r\n"
|
||||||
|
b"Content-Length: %i\r\n"
|
||||||
|
b"\r\n"
|
||||||
|
b"%s" % (len(res_json), res_json)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
res = self.successResultOf(test_d)
|
||||||
|
|
||||||
|
# check the response is as expected
|
||||||
|
self.assertEqual(res, {"a": 1})
|
||||||
|
|
||||||
def test_dns_error(self):
|
def test_dns_error(self):
|
||||||
"""
|
"""
|
||||||
If the DNS lookup returns an error, it will bubble up.
|
If the DNS lookup returns an error, it will bubble up.
|
||||||
@ -54,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
||||||
|
|
||||||
|
def test_client_connection_refused(self):
|
||||||
|
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(d)
|
||||||
|
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8008)
|
||||||
|
e = Exception("go away")
|
||||||
|
factory.clientConnectionFailed(None, e)
|
||||||
|
self.pump(0.5)
|
||||||
|
|
||||||
|
f = self.failureResultOf(d)
|
||||||
|
|
||||||
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
|
self.assertIs(f.value.inner_exception, e)
|
||||||
|
|
||||||
def test_client_never_connect(self):
|
def test_client_never_connect(self):
|
||||||
"""
|
"""
|
||||||
If the HTTP request is not connected and is timed out, it'll give a
|
If the HTTP request is not connected and is timed out, it'll give a
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from six import text_type
|
from six import text_type
|
||||||
@ -22,6 +23,8 @@ from synapse.util import Clock
|
|||||||
|
|
||||||
from tests.utils import setup_test_homeserver as _sth
|
from tests.utils import setup_test_homeserver as _sth
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TimedOutException(Exception):
|
class TimedOutException(Exception):
|
||||||
"""
|
"""
|
||||||
@ -339,7 +342,7 @@ def get_clock():
|
|||||||
return (clock, hs_clock)
|
return (clock, hs_clock)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s(cmp=False)
|
||||||
class FakeTransport(object):
|
class FakeTransport(object):
|
||||||
"""
|
"""
|
||||||
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
||||||
@ -414,6 +417,11 @@ class FakeTransport(object):
|
|||||||
self.buffer = self.buffer + byt
|
self.buffer = self.buffer + byt
|
||||||
|
|
||||||
def _write():
|
def _write():
|
||||||
|
if not self.buffer:
|
||||||
|
# nothing to do. Don't write empty buffers: it upsets the
|
||||||
|
# TLSMemoryBIOProtocol
|
||||||
|
return
|
||||||
|
|
||||||
if getattr(self.other, "transport") is not None:
|
if getattr(self.other, "transport") is not None:
|
||||||
self.other.dataReceived(self.buffer)
|
self.other.dataReceived(self.buffer)
|
||||||
self.buffer = b""
|
self.buffer = b""
|
||||||
@ -421,7 +429,10 @@ class FakeTransport(object):
|
|||||||
|
|
||||||
self._reactor.callLater(0.0, _write)
|
self._reactor.callLater(0.0, _write)
|
||||||
|
|
||||||
_write()
|
# always actually do the write asynchronously. Some protocols (notably the
|
||||||
|
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
|
||||||
|
# still doing a write. Doing a callLater here breaks the cycle.
|
||||||
|
self._reactor.callLater(0.0, _write)
|
||||||
|
|
||||||
def writeSequence(self, seq):
|
def writeSequence(self, seq):
|
||||||
for x in seq:
|
for x in seq:
|
||||||
|
Loading…
Reference in New Issue
Block a user