Merge branch 'develop' of github.com:matrix-org/synapse into erikj/redactions_eiah

This commit is contained in:
Erik Johnston 2019-01-29 22:00:33 +00:00
commit a696c48133
54 changed files with 1054 additions and 291 deletions

View File

@ -15,6 +15,7 @@ recursive-include docs *
recursive-include scripts * recursive-include scripts *
recursive-include scripts-dev * recursive-include scripts-dev *
recursive-include synapse *.pyi recursive-include synapse *.pyi
recursive-include tests *.pem
recursive-include tests *.py recursive-include tests *.py
recursive-include synapse/res * recursive-include synapse/res *

1
changelog.d/4408.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Refactor 'sign_request' as 'build_auth_headers'

1
changelog.d/4409.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Remove redundant federation connection wrapping code

1
changelog.d/4426.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Remove redundant SynapseKeyClientProtocol magic

1
changelog.d/4427.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Refactor and cleanup for SRV record lookup

1
changelog.d/4428.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Move SRV logic into the Agent layer

1
changelog.d/4464.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Move SRV logic into the Agent layer

1
changelog.d/4468.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Move SRV logic into the Agent layer

1
changelog.d/4481.misc Normal file
View File

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4483.feature Normal file
View File

@ -0,0 +1 @@
Add support for room version 3

1
changelog.d/4487.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View File

@ -1 +0,0 @@
Fix idna and ipv6 literal handling in MatrixFederationAgent

1
changelog.d/4489.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4498.misc Normal file
View File

@ -0,0 +1 @@
Clarify documentation for the `public_baseurl` config param

1
changelog.d/4506.misc Normal file
View File

@ -0,0 +1 @@
Make it possible to set the log level for tests via an environment variable

1
changelog.d/4509.removal Normal file
View File

@ -0,0 +1 @@
Synapse no longer generates self-signed TLS certificates when generating a configuration file.

1
changelog.d/4510.misc Normal file
View File

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4511.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4512.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug where setting a relative consent directory path would cause a crash.

View File

@ -550,17 +550,6 @@ class Auth(object):
""" """
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
builder.auth_events = auth_events_entries
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False): def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
@ -577,7 +566,7 @@ class Auth(object):
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event_id = current_state_ids.get(key) join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.sender, )
member_event_id = current_state_ids.get(key) member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )

View File

@ -125,10 +125,12 @@ class EventFormatVersions(object):
independently from the room version. independently from the room version.
""" """
V1 = 1 V1 = 1
V2 = 2
KNOWN_EVENT_FORMAT_VERSIONS = { KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.V1, EventFormatVersions.V1,
EventFormatVersions.V2,
} }

View File

@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from os import path
from synapse.config import ConfigError
from ._base import Config from ._base import Config
DEFAULT_CONFIG = """\ DEFAULT_CONFIG = """\
@ -85,7 +89,15 @@ class ConsentConfig(Config):
if consent_config is None: if consent_config is None:
return return
self.user_consent_version = str(consent_config["version"]) self.user_consent_version = str(consent_config["version"])
self.user_consent_template_dir = consent_config["template_dir"] self.user_consent_template_dir = self.abspath(
consent_config["template_dir"]
)
if not path.isdir(self.user_consent_template_dir):
raise ConfigError(
"Could not find template directory '%s'" % (
self.user_consent_template_dir,
),
)
self.user_consent_server_notice_content = consent_config.get( self.user_consent_server_notice_content = consent_config.get(
"server_notice_content", "server_notice_content",
) )

View File

@ -261,7 +261,7 @@ class ServerConfig(Config):
# enter into the 'custom HS URL' field on their client. If you # enter into the 'custom HS URL' field on their client. If you
# use synapse with a reverse proxy, this should be the URL to reach # use synapse with a reverse proxy, this should be the URL to reach
# synapse via the proxy. # synapse via the proxy.
# public_baseurl: https://example.com:8448/ # public_baseurl: https://example.com/
# Set the soft limit on the number of file descriptors synapse can use # Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the # Zero is used to indicate synapse should set the soft limit to the

View File

@ -15,6 +15,7 @@
import logging import logging
import os import os
import warnings
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
@ -39,8 +40,8 @@ class TlsConfig(Config):
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"]) self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30) self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path")) self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
self._original_tls_fingerprints = config["tls_fingerprints"] self._original_tls_fingerprints = config["tls_fingerprints"]
self.tls_fingerprints = list(self._original_tls_fingerprints) self.tls_fingerprints = list(self._original_tls_fingerprints)
self.no_tls = config.get("no_tls", False) self.no_tls = config.get("no_tls", False)
@ -94,6 +95,16 @@ class TlsConfig(Config):
""" """
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file) self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
# Check if it is self-signed, and issue a warning if so.
if self.tls_certificate.get_issuer() == self.tls_certificate.get_subject():
warnings.warn(
(
"Self-signed TLS certificates will not be accepted by Synapse 1.0. "
"Please either provide a valid certificate, or use Synapse's ACME "
"support to provision one."
)
)
if not self.no_tls: if not self.no_tls:
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file) self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
@ -118,10 +129,11 @@ class TlsConfig(Config):
return ( return (
"""\ """\
# PEM encoded X509 certificate for TLS. # PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse # This certificate, as of Synapse 1.0, will need to be a valid
# autogenerates on launch with your own SSL certificate + key pair # and verifiable certificate, with a root that is available in
# if you like. Any required intermediary certificates can be # the root store of other servers you wish to federate to. Any
# appended after the primary certificate in hierarchical order. # required intermediary certificates can be appended after the
# primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s" tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS # PEM encoded private key for TLS
@ -183,40 +195,3 @@ class TlsConfig(Config):
def read_tls_private_key(self, private_key_path): def read_tls_private_key(self, private_key_path):
private_key_pem = self.read_file(private_key_path, "tls_private_key") private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem) return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
def generate_files(self, config):
tls_certificate_path = config["tls_certificate_path"]
tls_private_key_path = config["tls_private_key_path"]
if not self.path_exists(tls_private_key_path):
with open(tls_private_key_path, "wb") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
crypto.FILETYPE_PEM, tls_private_key
)
private_key_file.write(private_key_pem)
else:
with open(tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem
)
if not self.path_exists(tls_certificate_path):
with open(tls_certificate_path, "wb") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(tls_private_key)
cert.sign(tls_private_key, 'sha256')
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
certificate_file.write(cert_pem)

View File

@ -131,12 +131,12 @@ def compute_event_signature(event_dict, signature_name, signing_key):
return redact_json["signatures"] return redact_json["signatures"]
def add_hashes_and_signatures(event, signature_name, signing_key, def add_hashes_and_signatures(event_dict, signature_name, signing_key,
hash_algorithm=hashlib.sha256): hash_algorithm=hashlib.sha256):
"""Add content hash and sign the event """Add content hash and sign the event
Args: Args:
event_dict (EventBuilder): The event to add hashes to and sign event_dict (dict): The event to add hashes to and sign
signature_name (str): The name of the entity signing the event signature_name (str): The name of the entity signing the event
(typically the server's hostname). (typically the server's hostname).
signing_key (syutil.crypto.SigningKey): The key to sign with signing_key (syutil.crypto.SigningKey): The key to sign with
@ -144,16 +144,12 @@ def add_hashes_and_signatures(event, signature_name, signing_key,
to hash the event to hash the event
""" """
name, digest = compute_content_hash( name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
event.get_pdu_json(), hash_algorithm=hash_algorithm,
)
if not hasattr(event, "hashes"): event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
event.hashes = {}
event.hashes[name] = encode_base64(digest)
event.signatures = compute_event_signature( event_dict["signatures"] = compute_event_signature(
event.get_pdu_json(), event_dict,
signature_name=signature_name, signature_name=signature_name,
signing_key=signing_key, signing_key=signing_key,
) )

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,11 +19,9 @@ from distutils.util import strtobool
import six import six
from synapse.api.constants import ( from unpaddedbase64 import encode_base64
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS, from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions
EventFormatVersions,
)
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
@ -240,16 +239,6 @@ class FrozenEvent(EventBase):
rejected_reason=rejected_reason, rejected_reason=rejected_reason,
) )
@staticmethod
def from_event(event):
e = FrozenEvent(
event.get_pdu_json()
)
e.internal_metadata = event.internal_metadata
return e
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()
@ -261,6 +250,85 @@ class FrozenEvent(EventBase):
) )
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
# copy.deepcopy
signatures = {
name: {sig_id: sig for sig_id, sig in sigs.items()}
for name, sigs in event_dict.pop("signatures", {}).items()
}
assert "event_id" not in event_dict
unsigned = dict(event_dict.pop("unsigned", {}))
# We intern these strings because they turn up a lot (especially when
# caching).
event_dict = intern_dict(event_dict)
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
self._event_id = None
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEventV2, self).__init__(
frozen_dict,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
)
@property
def event_id(self):
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
if self._event_id:
return self._event_id
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
"""
return self.prev_events
def auth_event_ids(self):
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
"""
return self.auth_events
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
def room_version_to_event_format(room_version): def room_version_to_event_format(room_version):
"""Converts a room version string to the event format """Converts a room version string to the event format
@ -274,7 +342,13 @@ def room_version_to_event_format(room_version):
# We should have already checked version, so this should not happen # We should have already checked version, so this should not happen
raise RuntimeError("Unrecognized room version %s" % (room_version,)) raise RuntimeError("Unrecognized room version %s" % (room_version,))
if room_version in (
RoomVersions.V1, RoomVersions.V2, RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
):
return EventFormatVersions.V1 return EventFormatVersions.V1
else:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
def event_type_from_format_version(format_version): def event_type_from_format_version(format_version):
@ -288,8 +362,12 @@ def event_type_from_format_version(format_version):
type: A type that can be initialized as per the initializer of type: A type that can be initialized as per the initializer of
`FrozenEvent` `FrozenEvent`
""" """
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
if format_version == EventFormatVersions.V1:
return FrozenEvent
elif format_version == EventFormatVersions.V2:
return FrozenEventV2
else:
raise Exception( raise Exception(
"No event format %r" % (format_version,) "No event format %r" % (format_version,)
) )
return FrozenEvent

View File

@ -13,78 +13,161 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import attr
from synapse.api.constants import RoomVersions from twisted.internet import defer
from synapse.api.constants import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
MAX_DEPTH,
EventFormatVersions,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.types import EventID from synapse.types import EventID
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from . import EventBase, FrozenEvent, _event_dict_property from . import (
_EventInternalMetadata,
event_type_from_format_version,
room_version_to_event_format,
)
def get_event_builder(room_version, key_values={}, internal_metadata_dict={}): @attr.s(slots=True, cmp=False, frozen=True)
"""Generate an event builder appropriate for the given room version class EventBuilder(object):
"""A format independent event builder used to build up the event content
before signing the event.
(Note that while objects of this class are frozen, the
content/unsigned/internal_metadata fields are still mutable)
Attributes:
format_version (int): Event format version
room_id (str)
type (str)
sender (str)
content (dict)
unsigned (dict)
internal_metadata (_EventInternalMetadata)
_state (StateHandler)
_auth (synapse.api.Auth)
_store (DataStore)
_clock (Clock)
_hostname (str): The hostname of the server creating the event
_signing_key: The signing key to use to sign the event as the server
"""
_state = attr.ib()
_auth = attr.ib()
_store = attr.ib()
_clock = attr.ib()
_hostname = attr.ib()
_signing_key = attr.ib()
format_version = attr.ib()
room_id = attr.ib()
type = attr.ib()
sender = attr.ib()
content = attr.ib(default=attr.Factory(dict))
unsigned = attr.ib(default=attr.Factory(dict))
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@property
def state_key(self):
if self._state_key is not None:
return self._state_key
raise AttributeError("state_key")
def is_state(self):
return self._state_key is not None
@defer.inlineCallbacks
def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event
Args: Args:
room_version (str): Version of the room that we're creating an prev_event_ids (list[str]): The event IDs to use as the prev events
event builder for
key_values (dict): Fields used as the basis of the new event
internal_metadata_dict (dict): Used to create the `_EventInternalMetadata`
object.
Returns: Returns:
EventBuilder Deferred[FrozenEvent]
""" """
if room_version in {
RoomVersions.V1, state_ids = yield self._state.get_current_state_ids(
RoomVersions.V2, self.room_id, prev_event_ids,
RoomVersions.STATE_V2_TEST, )
}: auth_ids = yield self._auth.compute_auth_events(
return EventBuilder(key_values, internal_metadata_dict) self, state_ids,
)
if self.format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
else: else:
raise Exception( auth_events = auth_ids
"No event format defined for version %r" % (room_version,) prev_events = prev_event_ids
old_depth = yield self._store.get_max_depth_of(
prev_event_ids,
) )
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
class EventBuilder(EventBase): event_dict = {
def __init__(self, key_values={}, internal_metadata_dict={}): "auth_events": auth_events,
signatures = copy.deepcopy(key_values.pop("signatures", {})) "prev_events": prev_events,
unsigned = copy.deepcopy(key_values.pop("unsigned", {})) "type": self.type,
"room_id": self.room_id,
"sender": self.sender,
"content": self.content,
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
}
super(EventBuilder, self).__init__( if self.is_state():
key_values, event_dict["state_key"] = self._state_key
signatures=signatures,
unsigned=unsigned, if self._redacts is not None:
internal_metadata_dict=internal_metadata_dict, event_dict["redacts"] = self._redacts
defer.returnValue(
create_local_event_from_event_dict(
clock=self._clock,
hostname=self._hostname,
signing_key=self._signing_key,
format_version=self.format_version,
event_dict=event_dict,
internal_metadata_dict=self.internal_metadata.get_dict(),
)
) )
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self):
return FrozenEvent.from_event(self)
class EventBuilderFactory(object): class EventBuilderFactory(object):
def __init__(self, clock, hostname): def __init__(self, hs):
self.clock = clock self.clock = hs.get_clock()
self.hostname = hostname self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self.event_id_count = 0 self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
def create_event_id(self): def new(self, room_version, key_values):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID(local_part, self.hostname)
return e_id.to_string()
def new(self, room_version, key_values={}):
"""Generate an event builder appropriate for the given room version """Generate an event builder appropriate for the given room version
Args: Args:
@ -97,26 +180,103 @@ class EventBuilderFactory(object):
""" """
# There's currently only the one event version defined # There's currently only the one event version defined
if room_version not in { if room_version not in KNOWN_ROOM_VERSIONS:
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.STATE_V2_TEST,
}:
raise Exception( raise Exception(
"No event format defined for version %r" % (room_version,) "No event format defined for version %r" % (room_version,)
) )
key_values["event_id"] = self.create_event_id() return EventBuilder(
store=self.store,
state=self.state,
auth=self.auth,
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
format_version=room_version_to_event_format(room_version),
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
sender=key_values["sender"],
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
)
time_now = int(self.clock.time_msec())
key_values.setdefault("origin", self.hostname) def create_local_event_from_event_dict(clock, hostname, signing_key,
key_values.setdefault("origin_server_ts", time_now) format_version, event_dict,
internal_metadata_dict=None):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
key_values.setdefault("unsigned", {}) Args:
age = key_values["unsigned"].pop("age", 0) clock (Clock)
key_values["unsigned"].setdefault("age_ts", time_now - age) hostname (str)
signing_key
format_version (int)
event_dict (dict)
internal_metadata_dict (dict|None)
key_values["signatures"] = {} Returns:
FrozenEvent
"""
return EventBuilder(key_values=key_values,) # There's currently only the one event version defined
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
)
if internal_metadata_dict is None:
internal_metadata_dict = {}
time_now = int(clock.time_msec())
if format_version == EventFormatVersions.V1:
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
event_dict["origin_server_ts"] = time_now
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)
event_dict["unsigned"].setdefault("age_ts", time_now - age)
event_dict.setdefault("signatures", {})
add_hashes_and_signatures(
event_dict,
hostname,
signing_key,
)
return event_type_from_format_version(format_version)(
event_dict, internal_metadata_dict=internal_metadata_dict,
)
# A counter used when generating new event IDs
_event_id_counter = 0
def _create_event_id(clock, hostname):
"""Create a new event ID
Args:
clock (Clock)
hostname (str): The server name for the event ID
Returns:
str
"""
global _event_id_counter
i = str(_event_id_counter)
_event_id_counter += 1
local_part = str(int(clock.time())) + i + random_string(5)
e_id = EventID(local_part, hostname)
return e_id.to_string()

View File

@ -267,6 +267,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
Returns: Returns:
dict dict
""" """
# FIXME(erikj): To handle the case of presence events and the like # FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase): if not isinstance(e, EventBase):
return e return e
@ -276,6 +277,8 @@ def serialize_event(e, time_now_ms, as_client_event=True,
# Should this strip out None's? # Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()} d = {k: v for k, v in e.get_dict().items()}
d["event_id"] = e.event_id
if "age_ts" in d["unsigned"]: if "age_ts" in d["unsigned"]:
d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"] d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"] del d["unsigned"]["age_ts"]

View File

@ -37,8 +37,7 @@ from synapse.api.errors import (
HttpResponseException, HttpResponseException,
SynapseError, SynapseError,
) )
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import builder, room_version_to_event_format
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -72,7 +71,8 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
self.event_builder_factory = hs.get_event_builder_factory() self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
@ -608,18 +608,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict: if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = [] pdu_dict["prev_state"] = []
# Strip off the fields that we want to clobber. ev = builder.create_local_event_from_event_dict(
pdu_dict.pop("origin", None) self._clock, self.hostname, self.signing_key,
pdu_dict.pop("origin_server_ts", None) format_version=event_format, event_dict=pdu_dict,
pdu_dict.pop("unsigned", None)
builder = self.event_builder_factory.new(room_version, pdu_dict)
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0]
) )
ev = builder.build()
defer.returnValue( defer.returnValue(
(destination, ev, event_format) (destination, ev, event_format)

View File

@ -322,7 +322,7 @@ class FederationServer(FederationBase):
if self.hs.is_mine_id(event.event_id): if self.hs.is_mine_id(event.event_id):
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
event, event.get_pdu_json(),
self.hs.hostname, self.hs.hostname,
self.hs.config.signing_key[0] self.hs.config.signing_key[0]
) )

View File

@ -1300,7 +1300,7 @@ class FederationHandler(BaseHandler):
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
event, event.get_pdu_json(),
self.hs.hostname, self.hs.hostname,
self.hs.config.signing_key[0] self.hs.config.signing_key[0]
) )

View File

@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -31,7 +31,6 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@ -545,40 +544,19 @@ class EventCreationHandler(object):
prev_events_and_hashes = \ prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id) yield self.store.get_prev_events_for_room(builder.room_id)
if prev_events_and_hashes:
depth = max([d for _, _, d in prev_events_and_hashes]) + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
else:
depth = 1
prev_events = [ prev_events = [
(event_id, prev_hashes) (event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes for event_id, prev_hashes, _ in prev_events_and_hashes
] ]
builder.prev_events = prev_events event = yield builder.build(
builder.depth = depth prev_event_ids=[p for p, _ in prev_events],
)
context = yield self.state.compute_event_context(builder) context = yield self.state.compute_event_context(event)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
if builder.is_state(): self.validator.validate_new(event)
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
signing_key = self.hs.config.signing_key[0]
add_hashes_and_signatures(
builder, self.server_name, signing_key
)
event = builder.build()
logger.debug( logger.debug(
"Created event %s", "Created event %s",

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
import attr import attr
@ -20,7 +21,7 @@ from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.client import URI, Agent, HTTPConnectionPool, readBody
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent from twisted.web.iweb import IAgent
@ -43,13 +44,19 @@ class MatrixFederationAgent(object):
tls_client_options_factory (ClientTLSOptionsFactory|None): tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS. factory to use for fetching client tls options, or none to disable TLS.
_well_known_tls_policy (IPolicyForHTTPS|None):
TLS policy to use for fetching .well-known files. None to use a default
(browser-like) implementation.
srv_resolver (SrvResolver|None): srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default SRVResolver impl to use for looking up SRV records. None to use a default
implementation. implementation.
""" """
def __init__( def __init__(
self, reactor, tls_client_options_factory, _srv_resolver=None, self, reactor, tls_client_options_factory,
_well_known_tls_policy=None,
_srv_resolver=None,
): ):
self._reactor = reactor self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory self._tls_client_options_factory = tls_client_options_factory
@ -62,6 +69,14 @@ class MatrixFederationAgent(object):
self._pool.maxPersistentPerHost = 5 self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60 self._pool.cachedConnectionTimeout = 2 * 60
agent_args = {}
if _well_known_tls_policy is not None:
# the param is called 'contextFactory', but actually passing a
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
agent_args['contextFactory'] = _well_known_tls_policy
_well_known_agent = Agent(self._reactor, pool=self._pool, **agent_args)
self._well_known_agent = _well_known_agent
@defer.inlineCallbacks @defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None): def request(self, method, uri, headers=None, bodyProducer=None):
""" """
@ -114,7 +129,11 @@ class MatrixFederationAgent(object):
class EndpointFactory(object): class EndpointFactory(object):
@staticmethod @staticmethod
def endpointForURI(_uri): def endpointForURI(_uri):
logger.info("Connecting to %s:%s", res.target_host, res.target_port) logger.info(
"Connecting to %s:%i",
res.target_host.decode("ascii"),
res.target_port,
)
ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port) ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
if tls_options is not None: if tls_options is not None:
ep = wrapClientTLS(tls_options, ep) ep = wrapClientTLS(tls_options, ep)
@ -127,7 +146,7 @@ class MatrixFederationAgent(object):
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri): def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
"""Helper for `request`: determine the routing for a Matrix URI """Helper for `request`: determine the routing for a Matrix URI
Args: Args:
@ -135,6 +154,9 @@ class MatrixFederationAgent(object):
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
if there is no explicit port given. if there is no explicit port given.
lookup_well_known (bool): True if we should look up the .well-known file if
there is no SRV record.
Returns: Returns:
Deferred[_RoutingResult] Deferred[_RoutingResult]
""" """
@ -169,6 +191,42 @@ class MatrixFederationAgent(object):
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield self._srv_resolver.resolve_service(service_name) server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list and lookup_well_known:
# try a .well-known lookup
well_known_server = yield self._get_well_known(parsed_uri.host)
if well_known_server:
# if we found a .well-known, start again, but don't do another
# .well-known lookup.
# parse the server name in the .well-known response into host/port.
# (This code is lifted from twisted.web.client.URI.fromBytes).
if b':' in well_known_server:
well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
try:
well_known_port = int(well_known_port)
except ValueError:
# the part after the colon could not be parsed as an int
# - we assume it is an IPv6 literal with no port (the closing
# ']' stops it being parsed as an int)
well_known_host, well_known_port = well_known_server, -1
else:
well_known_host, well_known_port = well_known_server, -1
new_uri = URI(
scheme=parsed_uri.scheme,
netloc=well_known_server,
host=well_known_host,
port=well_known_port,
path=parsed_uri.path,
params=parsed_uri.params,
query=parsed_uri.query,
fragment=parsed_uri.fragment,
)
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
defer.returnValue(res)
if not server_list: if not server_list:
target_host = parsed_uri.host target_host = parsed_uri.host
port = 8448 port = 8448
@ -190,6 +248,47 @@ class MatrixFederationAgent(object):
target_port=port, target_port=port,
)) ))
@defer.inlineCallbacks
def _get_well_known(self, server_name):
"""Attempt to fetch and parse a .well-known file for the given server
Args:
server_name (bytes): name of the server, from the requested url
Returns:
Deferred[bytes|None]: either the new server name, from the .well-known, or
None if there was no .well-known file.
"""
# FIXME: add a cache
uri = b"https://%s/.well-known/matrix/server" % (server_name, )
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
self._well_known_agent.request(b"GET", uri),
)
except Exception as e:
logger.info("Connection error fetching %s: %s", uri_str, e)
defer.returnValue(None)
body = yield make_deferred_yieldable(readBody(response))
if response.code != 200:
logger.info("Error response %i from %s", response.code, uri_str)
defer.returnValue(None)
try:
parsed_body = json.loads(body.decode('utf-8'))
logger.info("Response from .well-known: %s", parsed_body)
if not isinstance(parsed_body, dict):
raise Exception("not a dict")
if "m.server" not in parsed_body:
raise Exception("Missing key 'm.server'")
except Exception as e:
raise Exception("invalid .well-known response from %s: %s" % (uri_str, e,))
defer.returnValue(parsed_body["m.server"].encode("ascii"))
@attr.s @attr.s
class _RoutingResult(object): class _RoutingResult(object):

View File

@ -101,16 +101,7 @@ class ConsentResource(Resource):
"missing in config file.", "missing in config file.",
) )
# daemonize changes the cwd to /, so make the path absolute now. consent_template_directory = hs.config.user_consent_template_dir
consent_template_directory = path.abspath(
hs.config.user_consent_template_dir,
)
if not path.isdir(consent_template_directory):
raise ConfigError(
"Could not find template directory '%s'" % (
consent_template_directory,
),
)
loader = jinja2.FileSystemLoader(consent_template_directory) loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment( self._jinja_env = jinja2.Environment(

View File

@ -355,10 +355,7 @@ class HomeServer(object):
return Keyring(self) return Keyring(self)
def build_event_builder_factory(self): def build_event_builder_factory(self):
return EventBuilderFactory( return EventBuilderFactory(self)
clock=self.get_clock(),
hostname=self.hostname,
)
def build_filtering(self): def build_filtering(self):
return Filtering(self) return Filtering(self)

View File

@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return dict(txn) return dict(txn)
@defer.inlineCallbacks
def get_max_depth_of(self, event_ids):
"""Returns the max depth of a set of event IDs
Args:
event_ids (list[str])
Returns
Deferred[int]
"""
rows = yield self._simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=("depth",),
desc="get_max_depth_of",
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(max(row["depth"] for row in rows))
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
txn, txn,

View File

@ -50,8 +50,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
"homeserver.yaml", "homeserver.yaml",
"lemurs.win.log.config", "lemurs.win.log.config",
"lemurs.win.signing.key", "lemurs.win.signing.key",
"lemurs.win.tls.crt",
"lemurs.win.tls.key",
] ]
), ),
set(os.listdir(self.dir)), set(os.listdir(self.dir)),

75
tests/config/test_tls.py Normal file
View File

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from synapse.config.tls import TlsConfig
from tests.unittest import TestCase
class TLSConfigTests(TestCase):
def test_warn_self_signed(self):
"""
Synapse will give a warning when it loads a self-signed certificate.
"""
config_dir = self.mktemp()
os.mkdir(config_dir)
with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
f.write("""-----BEGIN CERTIFICATE-----
MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
QXV0b21hdGVkIFRlc3RpbmcgQXV0aG9yaXR5MSkwJwYJKoZIhvcNAQkBFhpzZWN1
cml0eUB0d2lzdGVkbWF0cml4LmNvbTAgFw0xNzA3MTIxNDAxNTNaGA8yMTE3MDYx
ODE0MDE1M1owgbcxCzAJBgNVBAYTAlRSMQ8wDQYDVQQIDAbDh29ydW0xFDASBgNV
BAcMC0JhxZ9tYWvDp8SxMRIwEAYDVQQDDAlsb2NhbGhvc3QxHDAaBgNVBAoME1R3
aXN0ZWQgTWF0cml4IExhYnMxJDAiBgNVBAsMG0F1dG9tYXRlZCBUZXN0aW5nIEF1
dGhvcml0eTEpMCcGCSqGSIb3DQEJARYac2VjdXJpdHlAdHdpc3RlZG1hdHJpeC5j
b20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDwT6kbqtMUI0sMkx4h
I+L780dA59KfksZCqJGmOsMD6hte9EguasfkZzvCF3dk3NhwCjFSOvKx6rCwiteo
WtYkVfo+rSuVNmt7bEsOUDtuTcaxTzIFB+yHOYwAaoz3zQkyVW0c4pzioiLCGCmf
FLdiDBQGGp74tb+7a0V6kC3vMLFoM3L6QWq5uYRB5+xLzlPJ734ltyvfZHL3Us6p
cUbK+3WTWvb4ER0W2RqArAj6Bc/ERQKIAPFEiZi9bIYTwvBH27OKHRz+KoY/G8zY
+l+WZoJqDhupRAQAuh7O7V/y6bSP+KNxJRie9QkZvw1PSaGSXtGJI3WWdO12/Ulg
epJpAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAJXEq5P9xwvP9aDkXIqzcD0L8sf8
ewlhlxTQdeqt2Nace0Yk18lIo2oj1t86Y8jNbpAnZJeI813Rr5M7FbHCXoRc/SZG
I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj
iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
-----END CERTIFICATE-----""")
config = {
"tls_certificate_path": os.path.join(config_dir, "cert.pem"),
"no_tls": True,
"tls_fingerprints": []
}
t = TlsConfig()
t.read_config(config)
t.read_certificate_from_disk()
warnings = self.flushWarnings()
self.assertEqual(len(warnings), 1)
self.assertEqual(
warnings[0]["message"],
(
"Self-signed TLS certificates will not be accepted by "
"Synapse 1.0. Please either provide a valid certificate, "
"or use Synapse's ACME support to provision one."
)
)

View File

@ -18,7 +18,7 @@ import nacl.signing
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.builder import EventBuilder from synapse.events import FrozenEvent
from tests import unittest from tests import unittest
@ -40,8 +40,7 @@ class EventSigningTestCase(unittest.TestCase):
self.signing_key.version = KEY_VER self.signing_key.version = KEY_VER
def test_sign_minimal(self): def test_sign_minimal(self):
builder = EventBuilder( event_dict = {
{
'event_id': "$0:domain", 'event_id': "$0:domain",
'origin': "domain", 'origin': "domain",
'origin_server_ts': 1000000, 'origin_server_ts': 1000000,
@ -49,11 +48,10 @@ class EventSigningTestCase(unittest.TestCase):
'type': "X", 'type': "X",
'unsigned': {'age_ts': 1000000}, 'unsigned': {'age_ts': 1000000},
} }
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
event = builder.build() event = FrozenEvent(event_dict)
self.assertTrue(hasattr(event, 'hashes')) self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes) self.assertIn('sha256', event.hashes)
@ -71,8 +69,7 @@ class EventSigningTestCase(unittest.TestCase):
) )
def test_sign_message(self): def test_sign_message(self):
builder = EventBuilder( event_dict = {
{
'content': {'body': "Here is the message content"}, 'content': {'body': "Here is the message content"},
'event_id': "$0:domain", 'event_id': "$0:domain",
'origin': "domain", 'origin': "domain",
@ -83,11 +80,10 @@ class EventSigningTestCase(unittest.TestCase):
'signatures': {}, 'signatures': {},
'unsigned': {'age_ts': 1000000}, 'unsigned': {'age_ts': 1000000},
} }
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
event = builder.build() event = FrozenEvent(event_dict)
self.assertTrue(hasattr(event, 'hashes')) self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes) self.assertIn('sha256', event.hashes)

View File

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from OpenSSL import SSL
def get_test_cert_file():
"""get the path to the test cert"""
# the cert file itself is made with:
#
# openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \
# -nodes -subj '/CN=testserv'
return os.path.join(
os.path.dirname(__file__),
'server.pem',
)
class ServerTLSContext(object):
"""A TLS Context which presents our test cert."""
def __init__(self):
self.filename = get_test_cert_file()
def getContext(self):
ctx = SSL.Context(SSL.TLSv1_METHOD)
ctx.use_certificate_file(self.filename)
ctx.use_privatekey_file(self.filename)
return ctx

View File

@ -17,18 +17,21 @@ import logging
from mock import Mock from mock import Mock
import treq import treq
from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.test.ssl_helpers import ServerTLSContext
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from twisted.web.iweb import IPolicyForHTTPS
from synapse.crypto.context_factory import ClientTLSOptionsFactory from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server from synapse.http.federation.srv_resolver import Server
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from tests.http import ServerTLSContext
from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase from tests.unittest import TestCase
@ -44,6 +47,7 @@ class MatrixFederationAgentTests(TestCase):
self.agent = MatrixFederationAgent( self.agent = MatrixFederationAgent(
reactor=self.reactor, reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(None), tls_client_options_factory=ClientTLSOptionsFactory(None),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver, _srv_resolver=self.mock_resolver,
) )
@ -65,10 +69,14 @@ class MatrixFederationAgentTests(TestCase):
# Normally this would be done by the TCP socket code in Twisted, but we are # Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here. # stubbing that out here.
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor)) client_protocol.makeConnection(
FakeTransport(server_tls_protocol, self.reactor, client_protocol),
)
# tell the server tls protocol to send its stuff back to the client, too # tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor)) server_tls_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, server_tls_protocol),
)
# give the reactor a pump to get the TLS juices flowing. # give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))
@ -101,9 +109,48 @@ class MatrixFederationAgentTests(TestCase):
try: try:
fetch_res = yield fetch_d fetch_res = yield fetch_d
defer.returnValue(fetch_res) defer.returnValue(fetch_res)
except Exception as e:
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
raise
finally: finally:
_check_logcontext(context) _check_logcontext(context)
def _handle_well_known_connection(self, client_factory, expected_sni, target_server):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
Args:
client_factory (IProtocolFactory): outgoing connection
expected_sni (bytes): SNI that we expect the outgoing connection to send
target_server (bytes): target server that we should redirect to in the
.well-known response.
"""
# make the connection for .well-known
well_known_server = self._make_connection(
client_factory,
expected_sni=expected_sni,
)
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
request = well_known_server.requests[0]
self._send_well_known_response(request, target_server)
def _send_well_known_response(self, request, target_server):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/.well-known/matrix/server')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv'],
)
# send back a response
request.write(b'{ "m.server": "%s" }' % (target_server,))
request.finish()
self.reactor.pump((0.1, ))
def test_get(self): def test_get(self):
""" """
happy-path test of a GET request with an explicit port happy-path test of a GET request with an explicit port
@ -283,9 +330,9 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))
self.successResultOf(test_d) self.successResultOf(test_d)
def test_get_hostname_no_srv(self): def test_get_no_srv_no_well_known(self):
""" """
Test the behaviour when the server name has no port, and no SRV record Test the behaviour when the server name has no port, no SRV, and no well-known
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = lambda _: []
@ -300,11 +347,24 @@ class MatrixFederationAgentTests(TestCase):
b"_matrix._tcp.testserv", b"_matrix._tcp.testserv",
) )
# Make sure treq is trying to connect # there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1) self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0] (host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4') self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
# fonx the connection
client_factory.clientConnectionFailed(None, Exception("nope"))
# attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
# .well-known request fails.
self.reactor.pump((0.4,))
# we should fall back to a direct connection
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448) self.assertEqual(port, 8448)
# make a test server, and wire up the client # make a test server, and wire up the client
@ -327,6 +387,67 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))
self.successResultOf(test_d) self.successResultOf(test_d)
def test_get_well_known(self):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known redirects elsewhere
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.testserv",
)
self.mock_resolver.resolve_service.reset_mock()
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
self._handle_well_known_connection(
client_factory, expected_sni=b"testserv", target_server=b"target-server",
)
# there should be another SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.target-server",
)
# now we should get a connection to the target server
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '1::f')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'target-server',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'target-server'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_hostname_srv(self): def test_get_hostname_srv(self):
""" """
Test the behaviour when there is a single SRV record Test the behaviour when there is a single SRV record
@ -372,6 +493,71 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))
self.successResultOf(test_d) self.successResultOf(test_d)
def test_get_well_known_srv(self):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known redirects to a place where there is a SRV.
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["srvtarget"] = "5.6.7.8"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.testserv",
)
self.mock_resolver.resolve_service.reset_mock()
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host=b"srvtarget", port=8443),
]
self._handle_well_known_connection(
client_factory, expected_sni=b"testserv", target_server=b"target-server",
)
# there should be another SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.target-server",
)
# now we should get a connection to the target of the SRV record
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '5.6.7.8')
self.assertEqual(port, 8443)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'target-server',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'target-server'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_idna_servername(self): def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in""" """test the behaviour when the server name has idna chars in"""
@ -390,11 +576,25 @@ class MatrixFederationAgentTests(TestCase):
b"_matrix._tcp.xn--bcher-kva.com", b"_matrix._tcp.xn--bcher-kva.com",
) )
# Make sure treq is trying to connect # there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1) self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0] (host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4') self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
# fonx the connection
client_factory.clientConnectionFailed(None, Exception("nope"))
# attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
# .well-known request fails.
self.reactor.pump((0.4,))
# We should fall back to port 8448
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448) self.assertEqual(port, 8448)
# make a test server, and wire up the client # make a test server, and wire up the client
@ -492,3 +692,11 @@ def _build_test_server():
def _log_request(request): def _log_request(request):
"""Implements Factory.log, which is expected by Request.finish""" """Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request) logger.info("Completed request %s", request)
@implementer(IPolicyForHTTPS)
class TrustingTLSPolicyForHTTPS(object):
"""An IPolicyForHTTPS which doesn't do any certificate verification"""
def creatorForNetloc(self, hostname, port):
certificateOptions = OpenSSLCertificateOptions()
return ClientTLSOptions(hostname, certificateOptions.getContext())

81
tests/http/server.pem Normal file
View File

@ -0,0 +1,81 @@
-----BEGIN PRIVATE KEY-----
MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCgF43/3lAgJ+p0
x7Rn8UcL8a4fctvdkikvZrCngw96LkB34Evfq8YGWlOVjU+f9naUJLAKMatmAfEN
r+rMX4VOXmpTwuu6iLtqwreUrRFMESyrmvQxa15p+y85gkY0CFmXMblv6ORbxHTG
ncBGwST4WK4Poewcgt6jcISFCESTUKu1zc3cw1ANIDRyDLB5K44KwIe36dcKckyN
Kdtv4BJ+3fcIZIkPJH62zqCypgFF1oiFt40uJzClxgHdJZlKYpgkfnDTckw4Y/Mx
9k8BbE310KAzUNMV9H7I1eEolzrNr66FQj1eN64X/dqO8lTbwCqAd4diCT4sIUk0
0SVsAUjNd3g8j651hx+Qb1t8fuOjrny8dmeMxtUgIBHoQcpcj76R55Fs7KZ9uar0
8OFTyGIze51W1jG2K/7/5M1zxIqrA+7lsXu5OR81s7I+Ng/UUAhiHA/z+42/aiNa
qEuk6tqj3rHfLctnCbtZ+JrRNqSSwEi8F0lMA021ivEd2eJV+284OyJjhXOmKHrX
QADHrmS7Sh4syTZvRNm9n+qWID0KdDr2Sji/KnS3Enp44HDQ4xriT6/xhwEGsyuX
oH5aAkdLznulbWkHBbyx1SUQSTLpOqzaioF9m1vRrLsFvrkrY3D253mPJ5eU9HM/
dilduFcUgj4rz+6cdXUAh+KK/v95zwIDAQABAoICAFG5tJPaOa0ws0/KYx5s3YgL
aIhFalhCNSQtmCDrlwsYcXDA3/rfBchYdDL0YKGYgBBAal3J3WXFt/j0xThvyu2m
5UC9UPl4s7RckrsjXqEmY1d3UxGnbhtMT19cUdpeKN42VCP9EBaIw9Rg07dLAkSF
gNYaIx6q8F0fI4eGIPvTQtUcqur4CfWpaxyNvckdovV6M85/YXfDwbCOnacPDGIX
jfSK3i0MxGMuOHr6o8uzKR6aBUh6WStHWcw7VXXTvzdiFNbckmx3Gb93rf1b/LBw
QFfx+tBKcC62gKroCOzXso/0sL9YTVeSD/DJZOiJwSiz3Dj/3u1IUMbVvfTU8wSi
CYS7Z+jHxwSOCSSNTXm1wO/MtDsNKbI1+R0cohr/J9pOMQvrVh1+2zSDOFvXAQ1S
yvjn+uqdmijRoV2VEGVHd+34C+ci7eJGAhL/f92PohuuFR2shUETgGWzpACZSJwg
j1d90Hs81hj07vWRb+xCeDh00vimQngz9AD8vYvv/S4mqRGQ6TZdfjLoUwSTg0JD
6sQgRXX026gQhLhn687vLKZfHwzQPZkpQdxOR0dTZ/ho/RyGGRJXH4kN4cA2tPr+
AKYQ29YXGlEzGG7OqikaZcprNWG6UFgEpuXyBxCgp9r4ladZo3J+1Rhgus8ZYatd
uO98q3WEBmP6CZ2n32mBAoIBAQDS/c/ybFTos0YpGHakwdmSfj5OOQJto2y8ywfG
qDHwO0ebcpNnS1+MA+7XbKUQb/3Iq7iJljkkzJG2DIJ6rpKynYts1ViYpM7M/t0T
W3V1gvUcUL62iqkgws4pnpWmubFkqV31cPSHcfIIclnzeQ1aOEGsGHNAvhty0ciC
DnkJACbqApvopFLOR5f6UFTtKExE+hDH0WqgpsCAKJ1L4g6pBzZatI32/CN9JEVU
tDbxLV75hHlFFjUrG7nT1rPyr/gI8Ceh9/2xeXPfjJUR0PrG3U1nwLqUCZkvFzO6
XpN2+A+/v4v5xqMjKDKDFy1oq6SCMomwv/viw6wl/84TMbolAoIBAQDCPiMecnR8
REik6tqVzQO/uSe9ZHjz6J15t5xdwaI6HpSwLlIkQPkLTjyXtFpemK5DOYRxrJvQ
remfrZrN2qtLlb/DKpuGPWRsPOvWCrSuNEp48ivUehtclljrzxAFfy0sM+fWeJ48
nTnR+td9KNhjNtZixzWdAy/mE+jdaMsXVnk66L73Uz+2WsnvVMW2R6cpCR0F2eP/
B4zDWRqlT2w47sePAB81mFYSQLvPC6Xcgg1OqMubfiizJI49c8DO6Jt+FFYdsxhd
kG52Eqa/Net6rN3ueiS6yXL5TU3Y6g96bPA2KyNCypucGcddcBfqaiVx/o4AH6yT
NrdsrYtyvk/jAoIBAQDHUwKVeeRJJbvdbQAArCV4MI155n+1xhMe1AuXkCQFWGtQ
nlBE4D72jmyf1UKnIbW2Uwv15xY6/ouVWYIWlj9+QDmMaozVP7Uiko+WDuwLRNl8
k4dn+dzHV2HejbPBG2JLv3lFOx23q1zEwArcaXrExaq9Ayg2fKJ/uVHcFAIiD6Oz
pR1XDY4w1A/uaN+iYFSVQUyDCQLbnEz1hej73CaPZoHh9Pq83vxD5/UbjVjuRTeZ
L55FNzKpc/r89rNvTPBcuUwnxplDhYKDKVNWzn9rSXwrzTY2Tk8J3rh+k4RqevSd
6D47jH1n5Dy7/TRn0ueKHGZZtTUnyEUkbOJo3ayFAoIBAHKDyZaQqaX9Z8p6fwWj
yVsFoK0ih8BcWkLBAdmwZ6DWGJjJpjmjaG/G3ygc9s4gO1R8m12dAnuDnGE8KzDD
gwtbrKM2Alyg4wyA2hTlWOH/CAzH0RlCJ9Fs/d1/xJVJBeuyajLiB3/6vXTS6qnq
I7BSSxAPG8eGcn21LSsjNeB7ZZtaTgNnu/8ZBUYo9yrgkWc67TZe3/ChldYxOOlO
qqHh/BqNWtjxB4VZTp/g4RbgQVInZ2ozdXEv0v/dt0UEk29ANAjsZif7F3RayJ2f
/0TilzCaJ/9K9pKNhaClVRy7Dt8QjYg6BIWCGSw4ApF7pLnQ9gySn95mersCkVzD
YDsCggEAb0E/TORjQhKfNQvahyLfQFm151e+HIoqBqa4WFyfFxe/IJUaLH/JSSFw
VohbQqPdCmaAeuQ8ERL564DdkcY5BgKcax79fLLCOYP5bT11aQx6uFpfl2Dcm6Z9
QdCRI4jzPftsd5fxLNH1XtGyC4t6vTic4Pji2O71WgWzx0j5v4aeDY4sZQeFxqCV
/q7Ee8hem1Rn5RFHu14FV45RS4LAWl6wvf5pQtneSKzx8YL0GZIRRytOzdEfnGKr
FeUlAj5uL+5/p0ZEgM7gPsEBwdm8scF79qSUn8UWSoXNeIauF9D4BDg8RZcFFxka
KILVFsq3cQC+bEnoM4eVbjEQkGs1RQ==
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIE/jCCAuagAwIBAgIJANFtVaGvJWZlMA0GCSqGSIb3DQEBCwUAMBMxETAPBgNV
BAMMCHRlc3RzZXJ2MCAXDTE5MDEyNzIyMDIzNloYDzIxMTkwMTAzMjIwMjM2WjAT
MREwDwYDVQQDDAh0ZXN0c2VydjCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
ggIBAKAXjf/eUCAn6nTHtGfxRwvxrh9y292SKS9msKeDD3ouQHfgS9+rxgZaU5WN
T5/2dpQksAoxq2YB8Q2v6sxfhU5ealPC67qIu2rCt5StEUwRLKua9DFrXmn7LzmC
RjQIWZcxuW/o5FvEdMadwEbBJPhYrg+h7ByC3qNwhIUIRJNQq7XNzdzDUA0gNHIM
sHkrjgrAh7fp1wpyTI0p22/gEn7d9whkiQ8kfrbOoLKmAUXWiIW3jS4nMKXGAd0l
mUpimCR+cNNyTDhj8zH2TwFsTfXQoDNQ0xX0fsjV4SiXOs2vroVCPV43rhf92o7y
VNvAKoB3h2IJPiwhSTTRJWwBSM13eDyPrnWHH5BvW3x+46OufLx2Z4zG1SAgEehB
ylyPvpHnkWzspn25qvTw4VPIYjN7nVbWMbYr/v/kzXPEiqsD7uWxe7k5HzWzsj42
D9RQCGIcD/P7jb9qI1qoS6Tq2qPesd8ty2cJu1n4mtE2pJLASLwXSUwDTbWK8R3Z
4lX7bzg7ImOFc6YoetdAAMeuZLtKHizJNm9E2b2f6pYgPQp0OvZKOL8qdLcSenjg
cNDjGuJPr/GHAQazK5egfloCR0vOe6VtaQcFvLHVJRBJMuk6rNqKgX2bW9GsuwW+
uStjcPbneY8nl5T0cz92KV24VxSCPivP7px1dQCH4or+/3nPAgMBAAGjUzBRMB0G
A1UdDgQWBBQcQZpzLzTk5KdS/Iz7sGCV7gTd/zAfBgNVHSMEGDAWgBQcQZpzLzTk
5KdS/Iz7sGCV7gTd/zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
AQAr/Pgha57jqYsDDX1LyRrVdqoVBpLBeB7x/p9dKYm7S6tBTDFNMZ0SZyQP8VEG
7UoC9/OQ9nCdEMoR7ZKpQsmipwcIqpXHS6l4YOkf5EEq5jpMgvlEesHmBJJeJew/
FEPDl1bl8d0tSrmWaL3qepmwzA+2lwAAouWk2n+rLiP8CZ3jZeoTXFqYYrUlEqO9
fHMvuWqTV4KCSyNY+GWCrnHetulgKHlg+W2J1mZnrCKcBhWf9C2DesTJO+JldIeM
ornTFquSt21hZi+k3aySuMn2N3MWiNL8XsZVsAnPSs0zA+2fxjJkShls8Gc7cCvd
a6XrNC+PY6pONguo7rEU4HiwbvnawSTngFFglmH/ImdA/HkaAekW6o82aI8/UxFx
V9fFMO3iKDQdOrg77hI1bx9RlzKNZZinE2/Pu26fWd5d2zqDWCjl8ykGQRAfXgYN
H3BjgyXLl+ao5/pOUYYtzm3ruTXTgRcy5hhL6hVTYhSrf9vYh4LNIeXNKnZ78tyG
TX77/kU2qXhBGCFEUUMqUNV/+ITir2lmoxVjknt19M07aGr8C7SgYt6Rs+qDpMiy
JurgvRh8LpVq4pHx1efxzxCFmo58DMrG40I0+CF3y/niNpOb1gp2wAqByRiORkds
f0ytW6qZ0TpHbD6gOtQLYDnhx3ISuX+QYSekVwQUpffeWQ==
-----END CERTIFICATE-----

View File

@ -354,6 +354,11 @@ class FakeTransport(object):
:type: twisted.internet.interfaces.IReactorTime :type: twisted.internet.interfaces.IReactorTime
""" """
_protocol = attr.ib(default=None)
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
disconnecting = False disconnecting = False
buffer = attr.ib(default=b'') buffer = attr.ib(default=b'')
producer = attr.ib(default=None) producer = attr.ib(default=None)
@ -364,8 +369,12 @@ class FakeTransport(object):
def getHost(self): def getHost(self):
return None return None
def loseConnection(self): def loseConnection(self, reason=None):
logger.info("FakeTransport: loseConnection(%s)", reason)
if not self.disconnecting:
self.disconnecting = True self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
def abortConnection(self): def abortConnection(self):
self.disconnecting = True self.disconnecting = True

View File

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for running the unit tests
"""

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import twisted.logger
from synapse.util.logcontext import LoggingContextFilter
class ToTwistedHandler(logging.Handler):
"""logging handler which sends the logs to the twisted log"""
tx_log = twisted.logger.Logger()
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace('warning', 'warn')
self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level),
log_entry.replace("{", r"(").replace("}", r")"),
)
def setup_logging():
"""Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.)
"""
root_logger = logging.getLogger()
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
)
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
root_logger.setLevel(log_level)

View File

@ -166,7 +166,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_message(self, user_id, content=None): def inject_message(self, user_id, content=None):
if content is None: if content is None:
content = {"body": "testytest"} content = {"body": "testytest", "msgtype": "m.text"}
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1, RoomVersions.V1,
{ {

View File

@ -31,38 +31,14 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.logcontext import LoggingContext, LoggingContextFilter from synapse.util.logcontext import LoggingContext
from tests.server import get_clock, make_request, render, setup_test_homeserver from tests.server import get_clock, make_request, render, setup_test_homeserver
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb from tests.utils import default_config, setupdb
setupdb() setupdb()
setup_logging()
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
)
class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger()
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace('warning', 'warn')
self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level),
log_entry.replace("{", r"(").replace("}", r")"),
)
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
rootLogger.addHandler(handler)
def around(target): def around(target):
@ -96,7 +72,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName) method = getattr(self, methodName)
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING)) level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@around(self) @around(self)
def setUp(orig): def setUp(orig):
@ -114,7 +90,7 @@ class TestCase(unittest.TestCase):
) )
old_level = logging.getLogger().level old_level = logging.getLogger().level
if old_level != level: if level is not None and old_level != level:
@around(self) @around(self)
def tearDown(orig): def tearDown(orig):
@ -123,6 +99,7 @@ class TestCase(unittest.TestCase):
return ret return ret
logging.getLogger().setLevel(level) logging.getLogger().setLevel(level)
return orig() return orig()
@around(self) @around(self)