Merge branch 'develop' of github.com:matrix-org/synapse into anoa/public_rooms_federate

This commit is contained in:
Erik Johnston 2019-02-25 15:08:18 +00:00
commit 4b9e5076c4
384 changed files with 18682 additions and 6320 deletions

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2018-9 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -27,4 +27,4 @@ try:
except ImportError:
pass
__version__ = "0.33.8"
__version__ = "0.99.1.1"

View file

@ -35,6 +35,7 @@ def request_registration(
server_location,
shared_secret,
admin=False,
user_type=None,
requests=_requests,
_print=print,
exit=sys.exit,
@ -45,7 +46,7 @@ def request_registration(
# Get the nonce
r = requests.get(url, verify=False)
if r.status_code is not 200:
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
@ -65,6 +66,9 @@ def request_registration(
mac.update(password.encode('utf8'))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
if user_type:
mac.update(b"\x00")
mac.update(user_type.encode('utf8'))
mac = mac.hexdigest()
@ -74,12 +78,13 @@ def request_registration(
"password": password,
"mac": mac,
"admin": admin,
"user_type": user_type,
}
_print("Sending registration request...")
r = requests.post(url, json=data, verify=False)
if r.status_code is not 200:
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
@ -91,7 +96,7 @@ def request_registration(
_print("Success!")
def register_new_user(user, password, server_location, shared_secret, admin):
def register_new_user(user, password, server_location, shared_secret, admin, user_type):
if not user:
try:
default_user = getpass.getuser()
@ -129,7 +134,8 @@ def register_new_user(user, password, server_location, shared_secret, admin):
else:
admin = False
request_registration(user, password, server_location, shared_secret, bool(admin))
request_registration(user, password, server_location, shared_secret,
bool(admin), user_type)
def main():
@ -154,6 +160,12 @@ def main():
default=None,
help="New password for user. Will prompt if omitted.",
)
parser.add_argument(
"-t",
"--user_type",
default=None,
help="User type as specified in synapse.api.constants.UserTypes",
)
admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument(
"-a",
@ -208,7 +220,8 @@ def main():
if args.admin or args.no_admin:
admin = args.admin
register_new_user(args.user, args.password, args.server_url, secret, admin)
register_new_user(args.user, args.password, args.server_url, secret,
admin, args.user_type)
if __name__ == "__main__":

View file

@ -65,7 +65,7 @@ class Auth(object):
register_cache("cache", "token_cache", self.token_cache)
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True,
@ -74,12 +74,16 @@ class Auth(object):
auth_events = {
(e.type, e.state_key): e for e in itervalues(auth_events)
}
self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
self.check(
room_version, event,
auth_events=auth_events, do_sig_check=do_sig_check,
)
def check(self, event, auth_events, do_sig_check=True):
def check(self, room_version, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
Args:
room_version (str): version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
@ -88,7 +92,9 @@ class Auth(object):
True if the auth checks pass.
"""
with Measure(self.clock, "auth.check"):
event_auth.check(event, auth_events, do_sig_check=do_sig_check)
event_auth.check(
room_version, event, auth_events, do_sig_check=do_sig_check
)
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
@ -188,17 +194,33 @@ class Auth(object):
"""
# Can optionally look elsewhere in the request (e.g. headers)
try:
user_id, app_service = yield self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(
synapse.types.create_requester(user_id, app_service=app_service)
)
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent",
default=[b""]
)[0].decode('ascii', 'surrogateescape')
access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
user_id, app_service = yield self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
if ip_addr and self.hs.config.track_appservice_user_ips:
yield self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
device_id="dummy-device", # stubbed
)
defer.returnValue(
synapse.types.create_requester(user_id, app_service=app_service)
)
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
token_id = user_info["token_id"]
@ -208,11 +230,6 @@ class Auth(object):
# stubbed out.
device_id = user_info.get("device_id")
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent",
default=[b""]
)[0].decode('ascii', 'surrogateescape')
if user and access_token and ip_addr:
yield self.store.insert_client_ip(
user_id=user.to_string(),
@ -289,20 +306,28 @@ class Auth(object):
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
if rights == "access":
# first look in the database
r = yield self._look_up_user_by_access_token(token)
if r:
defer.returnValue(r)
# otherwise it needs to be a valid macaroon
try:
user_id, guest = self._parse_and_validate_macaroon(token, rights)
except _InvalidMacaroonException:
# doesn't look like a macaroon: treat it as an opaque token which
# must be in the database.
# TODO: it would be nice to get rid of this, but apparently some
# people use access tokens which aren't macaroons
r = yield self._look_up_user_by_access_token(token)
defer.returnValue(r)
try:
user = UserID.from_string(user_id)
if guest:
if rights == "access":
if not guest:
# non-guest access tokens must be in the database
logger.warning("Unrecognised access token - not in store.")
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN,
)
# Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway).
#
@ -343,31 +368,15 @@ class Auth(object):
"device_id": None,
}
else:
# This codepath exists for several reasons:
# * so that we can actually return a token ID, which is used
# in some parts of the schema (where we probably ought to
# use device IDs instead)
# * the only way we currently have to invalidate an
# access_token is by removing it from the database, so we
# have to check here that it is still in the db
# * some attributes (notably device_id) aren't stored in the
# macaroon. They probably should be.
# TODO: build the dictionary from the macaroon once the
# above are fixed
ret = yield self._look_up_user_by_access_token(token)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
raise RuntimeError("Unknown rights setting %s", rights)
defer.returnValue(ret)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
except (
_InvalidMacaroonException,
pymacaroons.exceptions.MacaroonException,
TypeError,
ValueError,
) as e:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
@ -497,11 +506,8 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
logger.warn("Unrecognised access token - not in store.")
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(None)
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
@ -544,17 +550,6 @@ class Auth(object):
"""
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
builder.auth_events = auth_events_entries
@defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create:
@ -571,7 +566,7 @@ class Auth(object):
key = (EventTypes.JoinRules, "", )
join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.user_id, )
key = (EventTypes.Member, event.sender, )
member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", )
@ -621,7 +616,7 @@ class Auth(object):
defer.returnValue(auth_ids)
def check_redaction(self, event, auth_events):
def check_redaction(self, room_version, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
@ -634,7 +629,7 @@ class Auth(object):
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
return event_auth.check_redaction(event, auth_events)
return event_auth.check_redaction(room_version, event, auth_events)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
@ -791,9 +786,10 @@ class Auth(object):
threepid should never be set at the same time.
"""
# Never fail an auth check for the server notices users
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
if user_id == self.hs.config.server_notices_mxid:
is_support = yield self.store.is_support_user(user_id)
if user_id == self.hs.config.server_notices_mxid or is_support:
return
if self.hs.config.hs_disabled:
@ -818,7 +814,9 @@ class Auth(object):
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
if is_threepid_reserved(self.hs.config, threepid):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()

View file

@ -51,6 +51,7 @@ class LoginType(object):
EMAIL_IDENTITY = u"m.login.email.identity"
MSISDN = u"m.login.msisdn"
RECAPTCHA = u"m.login.recaptcha"
TERMS = u"m.login.terms"
DUMMY = u"m.login.dummy"
# Only for C/S API v1
@ -61,15 +62,18 @@ class LoginType(object):
class EventTypes(object):
Member = "m.room.member"
Create = "m.room.create"
Tombstone = "m.room.tombstone"
JoinRules = "m.room.join_rules"
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
ThirdPartyInvite = "m.room.third_party_invite"
Encryption = "m.room.encryption"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
# These are used for validation
@ -100,7 +104,14 @@ class ThirdPartyEntityKind(object):
class RoomVersions(object):
V1 = "1"
VDH_TEST = "vdh-test-version"
V2 = "2"
V3 = "3"
STATE_V2_TEST = "state-v2-test"
class RoomDisposition(object):
STABLE = "stable"
UNSTABLE = "unstable"
# the version we will give rooms which are created on this server
@ -108,7 +119,36 @@ DEFAULT_ROOM_VERSION = RoomVersions.V1
# vdh-test-version is a placeholder to get room versioning support working and tested
# until we have a working v2.
KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST}
KNOWN_ROOM_VERSIONS = {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
RoomVersions.V3,
}
class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
V1 = 1
V2 = 2
KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.V1,
EventFormatVersions.V2,
}
ServerNoticeMsgType = "m.server_notice"
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
class UserTypes(object):
"""Allows for user type specific behaviour. With the benefit of hindsight
'admin' and 'guest' users should also be UserTypes. Normal users are type None
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)

View file

@ -348,6 +348,24 @@ class IncompatibleRoomVersionError(SynapseError):
)
class RequestSendFailed(RuntimeError):
"""Sending a HTTP request over federation failed due to not being able to
talk to the remote server for some reason.
This exception is used to differentiate "expected" errors that arise due to
networking (e.g. DNS failures, connection timeouts etc), versus unexpected
errors (like programming errors).
"""
def __init__(self, inner_exception, can_retry):
super(RequestSendFailed, self).__init__(
"Failed to send request: %s: %s" % (
type(inner_exception).__name__, inner_exception,
)
)
self.inner_exception = inner_exception
self.can_retry = can_retry
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server
interactions.

View file

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from six import text_type
import jsonschema
from canonicaljson import json
from jsonschema import FormatChecker
@ -353,7 +355,7 @@ class Filter(object):
sender = event.user_id
room_id = None
ev_type = "m.presence"
is_url = False
contains_url = False
else:
sender = event.get("sender", None)
if not sender:
@ -368,13 +370,16 @@ class Filter(object):
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
is_url = "url" in event.get("content", {})
content = event.get("content", {})
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), text_type)
return self.check_fields(
room_id,
sender,
ev_type,
is_url,
contains_url,
)
def check_fields(self, room_id, sender, event_type, contains_url):
@ -439,6 +444,20 @@ class Filter(object):
def include_redundant_members(self):
return self.filter_json.get("include_redundant_members", False)
def with_room_ids(self, room_ids):
"""Returns a new filter with the given room IDs appended.
Args:
room_ids (iterable[unicode]): The room_ids to add
Returns:
filter: A new filter including the given rooms and the old
filter's rooms.
"""
newFilter = Filter(self.filter_json)
newFilter.rooms += room_ids
return newFilter
def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"):

View file

@ -24,11 +24,12 @@ from synapse.config import ConfigError
CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
FEDERATION_PREFIX = "/_matrix/federation/v1"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"

View file

@ -12,22 +12,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from synapse import python_dependencies # noqa: E402
sys.dont_write_bytecode = True
logger = logging.getLogger(__name__)
try:
python_dependencies.check_requirements()
except python_dependencies.MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (str(e),),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
except python_dependencies.DependencyException as e:
sys.stderr.writelines(e.message)
sys.exit(1)
def check_bind_error(e, address, bind_addresses):
"""
This method checks an exception occurred while binding on 0.0.0.0.
If :: is specified in the bind addresses a warning is shown.
The exception is still raised otherwise.
Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
because :: binds on both IPv4 and IPv6 (as per RFC 3493).
When binding on 0.0.0.0 after :: this can safely be ignored.
Args:
e (Exception): Exception that was caught.
address (str): Address on which binding was attempted.
bind_addresses (list): Addresses on which the service listens.
"""
if address == '0.0.0.0' and '::' in bind_addresses:
logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
else:
raise e

View file

@ -15,18 +15,38 @@
import gc
import logging
import signal
import sys
import traceback
import psutil
from daemonize import Daemonize
from twisted.internet import error, reactor
from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
from synapse.crypto import context_factory
from synapse.util import PreserveLoggingContext
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
_sighup_callbacks = []
def register_sighup(func):
"""
Register a function to be called when a SIGHUP occurs.
Args:
func (function): Function to be called when sent a SIGHUP signal.
Will be called with a single argument, the homeserver.
"""
_sighup_callbacks.append(func)
def start_worker_reactor(appname, config):
""" Run the reactor in the main process
@ -135,62 +155,154 @@ def listen_metrics(bind_addresses, port):
from prometheus_client import start_http_server
for host in bind_addresses:
reactor.callInThread(start_http_server, int(port),
addr=host, registry=RegistryProxy)
logger.info("Metrics now reporting on %s:%d", host, port)
logger.info("Starting metrics listener on %s:%d", host, port)
start_http_server(port, addr=host, registry=RegistryProxy)
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
Returns:
list[twisted.internet.tcp.Port]: listening for TCP connections
"""
r = []
for address in bind_addresses:
try:
reactor.listenTCP(
port,
factory,
backlog,
address
r.append(
reactor.listenTCP(
port,
factory,
backlog,
address
)
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
return r
def listen_ssl(
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
):
"""
Create an SSL socket for a port and several addresses
Create an TLS-over-TCP socket for a port and several addresses
Returns:
list of twisted.internet.tcp.Port listening for TLS connections
"""
r = []
for address in bind_addresses:
try:
reactor.listenSSL(
port,
factory,
context_factory,
backlog,
address
r.append(
reactor.listenSSL(
port,
factory,
context_factory,
backlog,
address
)
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
return r
def check_bind_error(e, address, bind_addresses):
def refresh_certificate(hs):
"""
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
"""
This method checks an exception occurred while binding on 0.0.0.0.
If :: is specified in the bind addresses a warning is shown.
The exception is still raised otherwise.
Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
because :: binds on both IPv4 and IPv6 (as per RFC 3493).
When binding on 0.0.0.0 after :: this can safely be ignored.
if not hs.config.has_tls_listener():
# attempt to reload the certs for the good of the tls_fingerprints
hs.config.read_certificate_from_disk(require_cert_and_key=False)
return
hs.config.read_certificate_from_disk(require_cert_and_key=True)
hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config)
if hs._listening_services:
logger.info("Updating context factories...")
for i in hs._listening_services:
# When you listenSSL, it doesn't make an SSL port but a TCP one with
# a TLS wrapping factory around the factory you actually want to get
# requests. This factory attribute is public but missing from
# Twisted's documentation.
if isinstance(i.factory, TLSMemoryBIOFactory):
addr = i.getHost()
logger.info(
"Replacing TLS context factory on [%s]:%i", addr.host, addr.port,
)
# We want to replace TLS factories with a new one, with the new
# TLS configuration. We do this by reaching in and pulling out
# the wrappedFactory, and then re-wrapping it.
i.factory = TLSMemoryBIOFactory(
hs.tls_server_context_factory,
False,
i.factory.wrappedFactory
)
logger.info("Context factories updated.")
def start(hs, listeners=None):
"""
Start a Synapse server or worker.
Args:
e (Exception): Exception that was caught.
address (str): Address on which binding was attempted.
bind_addresses (list): Addresses on which the service listens.
hs (synapse.server.HomeServer)
listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
"""
if address == '0.0.0.0' and '::' in bind_addresses:
logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
else:
raise e
try:
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
def handle_sighup(*args, **kwargs):
for i in _sighup_callbacks:
i(hs)
signal.signal(signal.SIGHUP, handle_sighup)
register_sighup(refresh_certificate)
# Load the certificate from disk.
refresh_certificate(hs)
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.get_datastore().start_profiling()
setup_sentry(hs)
except Exception:
traceback.print_exc(file=sys.stderr)
reactor = hs.get_reactor()
if reactor.running:
reactor.stop()
sys.exit(1)
def setup_sentry(hs):
"""Enable sentry integration, if enabled in configuration
Args:
hs (synapse.server.HomeServer)
"""
if not hs.config.sentry_enabled:
return
import sentry_sdk
sentry_sdk.init(
dsn=hs.config.sentry_dsn,
release=get_version_string(synapse),
)
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
scope.set_tag("matrix_server_name", hs.config.server_name)
app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver"
name = hs.config.worker_name if hs.config.worker_name else "master"
scope.set_tag("worker_app", app)
scope.set_tag("worker_name", name)

View file

@ -168,12 +168,7 @@ def start(config_options):
)
ps.setup()
ps.start_listening(config.worker_listeners)
def start():
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
reactor.callWhenRunning(_base.start, ps, config.worker_listeners)
_base.start_worker_reactor("synapse-appservice", config)

View file

@ -25,7 +25,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@ -41,6 +40,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.login import LoginRestServlet
from synapse.rest.client.v1.room import (
JoinedRoomMemberListRestServlet,
PublicRoomListRestServlet,
@ -48,6 +48,7 @@ from synapse.rest.client.v1.room import (
RoomMemberListRestServlet,
RoomStateRestServlet,
)
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
@ -93,6 +94,8 @@ class ClientReaderServer(HomeServer):
JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
@ -164,26 +167,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-client-reader", config)

View file

@ -25,7 +25,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@ -185,26 +184,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = EventCreatorServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-event-creator", config)

View file

@ -26,7 +26,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@ -41,6 +40,7 @@ from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
@ -63,6 +63,7 @@ class FederationReaderSlavedStore(
SlavedReceiptsStore,
SlavedEventStore,
SlavedKeyStore,
SlavedRegistrationStore,
RoomStore,
DirectoryStore,
SlavedTransactionStore,
@ -87,6 +88,16 @@ class FederationReaderServer(HomeServer):
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
if name == "openid" and "federation" not in res["names"]:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
resources.update({
FEDERATION_PREFIX: TransportLayerServer(
self,
servlet_groups=["openid"],
),
})
root_resource = create_resource_tree(resources, NoResource())
@ -99,7 +110,8 @@ class FederationReaderServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
reactor=self.get_reactor()
)
logger.info("Synapse federation reader now listening on port %d", port)
@ -151,26 +163,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-federation-reader", config)

View file

@ -25,7 +25,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.federation import send_queue
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@ -183,26 +182,17 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.send_federation = True
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = FederationSenderServer(
ss = FederationSenderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
ss.setup()
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
def start():
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
_base.start_worker_reactor("synapse-federation-sender", config)

View file

@ -21,12 +21,11 @@ from twisted.web.resource import NoResource
import synapse
from synapse import events
from synapse.api.errors import SynapseError
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseSite
@ -67,10 +66,15 @@ class PresenceStatusStubServlet(ClientV1RestServlet):
headers = {
"Authorization": auth_headers,
}
result = yield self.http_client.get_json(
self.main_uri + request.uri.decode('ascii'),
headers=headers,
)
try:
result = yield self.http_client.get_json(
self.main_uri + request.uri.decode('ascii'),
headers=headers,
)
except HttpResponseException as e:
raise e.to_synapse_error()
defer.returnValue((200, result))
@defer.inlineCallbacks
@ -241,26 +245,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FrontendProxyServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-frontend-proxy", config)

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,6 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gc
import logging
import os
@ -25,6 +29,7 @@ from prometheus_client import Gauge
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from twisted.web.resource import EncodingResourceWrapper, NoResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@ -37,7 +42,6 @@ from synapse.api.urls import (
FEDERATION_PREFIX,
LEGACY_MEDIA_PREFIX,
MEDIA_PREFIX,
SERVER_KEY_PREFIX,
SERVER_KEY_V2_PREFIX,
STATIC_PREFIX,
WEB_CLIENT_PREFIX,
@ -46,7 +50,6 @@ from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
@ -55,13 +58,13 @@ from synapse.metrics import RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.module_api import ModuleApi
from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirements
from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore, are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
@ -81,36 +84,6 @@ def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
def build_resource_for_web_client(hs):
webclient_path = hs.get_config().web_client_location
if not webclient_path:
try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": CONDITIONAL_REQUIREMENTS["web_client"].keys()[0]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable?
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
@ -120,12 +93,13 @@ class SynapseHomeServer(HomeServer):
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
if tls and config.no_tls:
return
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "openid" and "federation" in res["names"]:
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
resources.update(self._configure_named_resource(
name, res.get("compress", False),
))
@ -139,15 +113,18 @@ class SynapseHomeServer(HomeServer):
handler = handler_cls(config, module_api)
resources[path] = AdditionalResource(self, handler.handle_request)
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
elif STATIC_PREFIX in resources:
root_resource = RootRedirect(STATIC_PREFIX)
else:
root_resource = NoResource()
root_resource = create_resource_tree(resources, root_resource)
if tls:
listen_ssl(
ports = listen_ssl(
bind_addresses,
port,
SynapseSite(
@ -158,10 +135,12 @@ class SynapseHomeServer(HomeServer):
self.version_string,
),
self.tls_server_context_factory,
reactor=self.get_reactor(),
)
logger.info("Synapse now listening on TCP port %d (TLS)", port)
else:
listen_tcp(
ports = listen_tcp(
bind_addresses,
port,
SynapseSite(
@ -170,9 +149,12 @@ class SynapseHomeServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
reactor=self.get_reactor(),
)
logger.info("Synapse now listening on port %d", port)
logger.info("Synapse now listening on TCP port %d", port)
return ports
def _configure_named_resource(self, name, compress=False):
"""Build a resource map for a named resource
@ -197,8 +179,13 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
})
if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource
resources["/_matrix/saml2"] = SAML2Resource(self)
if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource
consent_resource = ConsentResource(self)
@ -213,6 +200,11 @@ class SynapseHomeServer(HomeServer):
FEDERATION_PREFIX: TransportLayerServer(self),
})
if name == "openid":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]),
})
if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: File(
@ -236,13 +228,19 @@ class SynapseHomeServer(HomeServer):
)
if name in ["keys", "federation"]:
resources.update({
SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
})
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "webclient":
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
webclient_path = self.get_config().web_client_location
if webclient_path is None:
logger.warning(
"Not enabling webclient resource, as web_client_location is unset."
)
else:
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
resources[WEB_CLIENT_PREFIX] = File(webclient_path)
if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
@ -252,12 +250,14 @@ class SynapseHomeServer(HomeServer):
return resources
def start_listening(self):
def start_listening(self, listeners):
config = self.get_config()
for listener in config.listeners:
for listener in listeners:
if listener["type"] == "http":
self._listener_http(config, listener)
self._listening_services.extend(
self._listener_http(config, listener)
)
elif listener["type"] == "manhole":
listen_tcp(
listener["bind_addresses"],
@ -269,14 +269,14 @@ class SynapseHomeServer(HomeServer):
)
)
elif listener["type"] == "replication":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
factory = ReplicationStreamProtocolFactory(self)
server_listener = reactor.listenTCP(
listener["port"], factory, interface=address
)
services = listen_tcp(
listener["bind_addresses"],
listener["port"],
ReplicationStreamProtocolFactory(self),
)
for s in services:
reactor.addSystemEventTrigger(
"before", "shutdown", server_listener.stopListening,
"before", "shutdown", s.stopListening,
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
@ -337,24 +337,19 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
synapse.config.logger.setup_logging(config, use_worker_options=False)
# check any extra requirements we have now we have a config
check_requirements(config)
synapse.config.logger.setup_logging(
config,
use_worker_options=False
)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
@ -381,12 +376,79 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
hs.start_listening()
@defer.inlineCallbacks
def do_acme():
"""
Reprovision an ACME certificate, if it's required.
Returns:
Deferred[bool]: Whether the cert has been updated.
"""
acme = hs.get_acme_handler()
# Check how long the certificate is active for.
cert_days_remaining = hs.config.is_disk_cert_valid(
allow_self_signed=False
)
# We want to reprovision if cert_days_remaining is None (meaning no
# certificate exists), or the days remaining number it returns
# is less than our re-registration threshold.
provision = False
if (
cert_days_remaining is None or
cert_days_remaining < hs.config.acme_reprovision_threshold
):
provision = True
if provision:
yield acme.provision_certificate()
defer.returnValue(provision)
@defer.inlineCallbacks
def reprovision_acme():
"""
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
reprovisioned = yield do_acme()
if reprovisioned:
_base.refresh_certificate(hs)
@defer.inlineCallbacks
def start():
hs.get_pusherpool().start()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
yield acme.start_listening()
yield do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(
reprovision_acme,
24 * 60 * 60 * 1000
)
_base.start(hs, config.listeners)
hs.get_pusherpool().start()
hs.get_datastore().start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
# this gives better tracebacks than traceback.print_exc()
Failure().printTraceback(file=sys.stderr)
if reactor.running:
reactor.stop()
sys.exit(1)
reactor.callWhenRunning(start)
@ -394,7 +456,8 @@ def setup(config_options):
class SynapseService(service.Service):
"""A twisted Service class that will start synapse. Used to run synapse
"""
A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""
def __init__(self, config):
@ -540,7 +603,7 @@ def run(hs):
current_mau_count = 0
reserved_count = 0
store = hs.get_datastore()
if hs.config.limit_usage_by_mau:
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
current_mau_count = yield store.get_monthly_active_count()
reserved_count = yield store.get_registered_reserved_users_count()
current_mau_gauge.set(float(current_mau_count))
@ -554,7 +617,7 @@ def run(hs):
)
start_generate_monthly_active_users()
if hs.config.limit_usage_by_mau:
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings

View file

@ -26,7 +26,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
@ -151,26 +150,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-media-repository", config)

View file

@ -224,11 +224,10 @@ def start(config_options):
)
ps.setup()
ps.start_listening(config.worker_listeners)
def start():
_base.start(ps, config.worker_listeners)
ps.get_pusherpool().start()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)

View file

@ -226,7 +226,15 @@ class SynchrotronPresence(object):
class SynchrotronTyping(object):
def __init__(self, hs):
self._latest_room_serial = 0
self._reset()
def _reset(self):
"""
Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
def stream_positions(self):
@ -236,6 +244,12 @@ class SynchrotronTyping(object):
return {"typing": self._latest_room_serial}
def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
# clear everything.
self._reset()
# Set the latest serial token to whatever the server gave us.
self._latest_room_serial = token
for row in rows:
@ -431,12 +445,7 @@ def start(config_options):
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-synchrotron", config)

View file

@ -26,7 +26,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@ -211,26 +210,16 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.update_user_directory = True
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = UserDirectoryServer(
ss = UserDirectoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def start():
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
ss.setup()
reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-user-dir", config)

View file

@ -53,8 +53,8 @@ import logging
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -104,14 +104,23 @@ class _ServiceQueuer(object):
self.clock = clock
def enqueue(self, service, event):
# if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event)
run_in_background(self._send_request, service)
# start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight:
return
run_as_background_process(
"as-sender-%s" % (service.id, ),
self._send_request, service,
)
@defer.inlineCallbacks
def _send_request(self, service):
if service.id in self.requests_in_flight:
return
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert(service.id not in self.requests_in_flight)
self.requests_in_flight.add(service.id)
try:
@ -119,12 +128,10 @@ class _ServiceQueuer(object):
events = self.queued_events.pop(service.id, [])
if not events:
return
with Measure(self.clock, "servicequeuer.send"):
try:
yield self.txn_ctrl.send(service, events)
except Exception:
logger.exception("AS request failed")
try:
yield self.txn_ctrl.send(service, events)
except Exception:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
@ -223,7 +230,12 @@ class _Recoverer(object):
self.backoff_counter = 1
def recover(self):
self.clock.call_later((2 ** self.backoff_counter), self.retry)
def _retry():
run_as_background_process(
"as-recoverer-%s" % (self.service.id,),
self.retry,
)
self.clock.call_later((2 ** self.backoff_counter), _retry)
def _backoff(self):
# cap the backoff to be around 8.5min => (2^9) = 512 secs

View file

@ -16,7 +16,7 @@ from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
from homeserver import HomeServerConfig
from synapse.config.homeserver import HomeServerConfig
action = sys.argv[1]

View file

@ -134,10 +134,6 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
@staticmethod
def default_path(name):
return os.path.abspath(os.path.join(os.path.curdir, name))
@staticmethod
def read_config_file(file_path):
with open(file_path) as file_stream:
@ -151,8 +147,39 @@ class Config(object):
return results
def generate_config(
self, config_dir_path, server_name, is_generating_file, report_stats=None
self,
config_dir_path,
data_dir_path,
server_name,
generate_secrets=False,
report_stats=None,
):
"""Build a default configuration file
This is used both when the user explicitly asks us to generate a config file
(eg with --generate_config), and before loading the config at runtime (to give
a base which the config files override)
Args:
config_dir_path (str): The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
data_dir_path (str): The path where the data files are kept. Used to create
filenames for things like the database and media store.
server_name (str): The server name. Used to initialise the server_name
config param, but also used in the names of some of the config files.
generate_secrets (bool): True if we should generate new secrets for things
like the macaroon_secret_key. If False, these parameters will be left
unset.
report_stats (bool|None): Initial setting for the report_stats setting.
If None, report_stats will be left unset.
Returns:
str: the yaml config file
"""
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(
@ -160,15 +187,14 @@ class Config(object):
for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
data_dir_path=data_dir_path,
server_name=server_name,
is_generating_file=is_generating_file,
generate_secrets=generate_secrets,
report_stats=report_stats,
)
)
config = yaml.load(default_config)
return default_config, config
return default_config
@classmethod
def load_config(cls, description, argv):
@ -231,7 +257,7 @@ class Config(object):
"--keys-directory",
metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as"
" certs and signing keys should be stored in, unless explicitly"
" signing keys should be stored, unless explicitly"
" specified in the config.",
)
config_parser.add_argument(
@ -274,27 +300,24 @@ class Config(object):
if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_str, config = obj.generate_config(
config_str = obj.generate_config(
config_dir_path=config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
is_generating_file=True,
generate_secrets=True,
)
config = yaml.load(config_str)
obj.invoke_all("generate_files", config)
config_file.write(config_str)
print(
(
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" %r. Please review this file and customise it"
" to your needs."
)
% (config_path, server_name)
)
print(
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
return
else:
print(
@ -339,7 +362,7 @@ class Config(object):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])
config_dir_path = os.path.abspath(keys_directory)
self.config_dir_path = os.path.abspath(keys_directory)
specified_config = {}
for config_file in config_files:
@ -350,11 +373,13 @@ class Config(object):
raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
_, config = self.generate_config(
config_dir_path=config_dir_path,
config_string = self.generate_config(
config_dir_path=self.config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
is_generating_file=False,
generate_secrets=False,
)
config = yaml.load(config_string)
config.pop("log_config")
config.update(specified_config)

View file

@ -24,6 +24,7 @@ class ApiConfig(Config):
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.RoomEncryption,
EventTypes.Name,
])
@ -32,9 +33,11 @@ class ApiConfig(Config):
## API Configuration ##
# A list of event types that will be included in the room_invite_state
#
room_invite_state_types:
- "{JoinRules}"
- "{CanonicalAlias}"
- "{RoomAvatar}"
- "{RoomEncryption}"
- "{Name}"
""".format(**vars(EventTypes))

View file

@ -33,11 +33,18 @@ class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
def default_config(cls, **kwargs):
return """\
# A list of application service config file to use
#
app_service_config_files: []
# Whether or not to track application service IP addresses. Implicitly
# enables MAU tracking for application service users.
#
track_appservice_user_ips: False
"""

View file

@ -30,14 +30,17 @@ class CaptchaConfig(Config):
# See docs/CAPTCHA_SETUP for full details of configuring this.
# This Home Server's ReCAPTCHA public key.
#
recaptcha_public_key: "YOUR_PUBLIC_KEY"
# This Home Server's ReCAPTCHA private key.
#
recaptcha_private_key: "YOUR_PRIVATE_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
#
enable_registration_captcha: False
# A secret key used to bypass the captcha test entirely.

View file

@ -38,6 +38,7 @@ class CasConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#
#cas_config:
# enabled: true
# server_url: "https://cas-server.com"

View file

@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from os import path
from synapse.config import ConfigError
from ._base import Config
DEFAULT_CONFIG = """\
@ -42,18 +46,28 @@ DEFAULT_CONFIG = """\
# until the user consents to the privacy policy. The value of the setting is
# used as the text of the error.
#
# user_consent:
# template_dir: res/templates/privacy
# version: 1.0
# server_notice_content:
# msgtype: m.text
# body: >-
# To continue using this homeserver you must review and agree to the
# terms and conditions at %(consent_uri)s
# send_server_notice_to_guests: True
# block_events_error: >-
# To continue using this homeserver you must review and agree to the
# terms and conditions at %(consent_uri)s
# 'require_at_registration', if enabled, will add a step to the registration
# process, similar to how captcha works. Users will be required to accept the
# policy before their account is created.
#
# 'policy_name' is the display name of the policy users will see when registering
# for an account. Has no effect unless `require_at_registration` is enabled.
# Defaults to "Privacy Policy".
#
#user_consent:
# template_dir: res/templates/privacy
# version: 1.0
# server_notice_content:
# msgtype: m.text
# body: >-
# To continue using this homeserver you must review and agree to the
# terms and conditions at %(consent_uri)s
# send_server_notice_to_guests: True
# block_events_error: >-
# To continue using this homeserver you must review and agree to the
# terms and conditions at %(consent_uri)s
# require_at_registration: False
# policy_name: Privacy Policy
#
"""
@ -67,13 +81,23 @@ class ConsentConfig(Config):
self.user_consent_server_notice_content = None
self.user_consent_server_notice_to_guests = False
self.block_events_without_consent_error = None
self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy"
def read_config(self, config):
consent_config = config.get("user_consent")
if consent_config is None:
return
self.user_consent_version = str(consent_config["version"])
self.user_consent_template_dir = consent_config["template_dir"]
self.user_consent_template_dir = self.abspath(
consent_config["template_dir"]
)
if not path.isdir(self.user_consent_template_dir):
raise ConfigError(
"Could not find template directory '%s'" % (
self.user_consent_template_dir,
),
)
self.user_consent_server_notice_content = consent_config.get(
"server_notice_content",
)
@ -83,6 +107,12 @@ class ConsentConfig(Config):
self.user_consent_server_notice_to_guests = bool(consent_config.get(
"send_server_notice_to_guests", False,
))
self.user_consent_at_registration = bool(consent_config.get(
"require_at_registration", False,
))
self.user_consent_policy_name = consent_config.get(
"policy_name", "Privacy Policy",
)
def default_config(self, **kwargs):
return DEFAULT_CONFIG

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from ._base import Config
@ -45,8 +46,8 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path"))
def default_config(self, **kwargs):
database_path = self.abspath("homeserver.db")
def default_config(self, data_dir_path, **kwargs):
database_path = os.path.join(data_dir_path, "homeserver.db")
return """\
# Database configuration
database:

View file

@ -24,9 +24,11 @@ class GroupsConfig(Config):
def default_config(self, **kwargs):
return """\
# Whether to allow non server admins to create groups on this server
#
enable_group_creation: false
# If enabled, non server admins can only create groups with local parts
# starting with this prefix
# group_creation_prefix: "unofficial/"
#
#group_creation_prefix: "unofficial/"
"""

View file

@ -32,7 +32,7 @@ from .ratelimiting import RatelimitConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
from .room_directory import RoomDirectoryConfig
from .saml2 import SAML2Config
from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
@ -42,7 +42,7 @@ from .voip import VoipConfig
from .workers import WorkerConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
@ -53,10 +53,3 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
ServerNoticesConfig, RoomDirectoryConfig,
):
pass
if __name__ == '__main__':
import sys
sys.stdout.write(
HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0]
)

View file

@ -46,8 +46,8 @@ class JWTConfig(Config):
return """\
# The JWT needs to contain a globally unique "sub" (subject) claim.
#
# jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
#jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
"""

View file

@ -40,7 +40,7 @@ class KeyConfig(Config):
def read_config(self, config):
self.signing_key = self.read_signing_key(config["signing_key_path"])
self.old_signing_keys = self.read_old_signing_keys(
config["old_signing_keys"]
config.get("old_signing_keys", {})
)
self.key_refresh_interval = self.parse_duration(
config["key_refresh_interval"]
@ -56,9 +56,9 @@ class KeyConfig(Config):
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
logger.warn("Config is missing missing macaroon_secret_key")
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
logger.warn("Config is missing macaroon_secret_key")
seed = bytes(self.signing_key[0])
self.macaroon_secret_key = hashlib.sha256(seed).digest()
self.expire_access_token = config.get("expire_access_token", False)
@ -66,35 +66,46 @@ class KeyConfig(Config):
# falsification of values
self.form_secret = config.get("form_secret", None)
def default_config(self, config_dir_path, server_name, is_generating_file=False,
def default_config(self, config_dir_path, server_name, generate_secrets=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
form_secret = '"%s"' % random_string_with_symbols(50)
if generate_secrets:
macaroon_secret_key = 'macaroon_secret_key: "%s"' % (
random_string_with_symbols(50),
)
form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
else:
macaroon_secret_key = None
form_secret = 'null'
macaroon_secret_key = "# macaroon_secret_key: <PRIVATE STRING>"
form_secret = "# form_secret: <PRIVATE STRING>"
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
# a secret which is used to sign access tokens. If none is specified,
# the registration_shared_secret is used, if one is given; otherwise,
# a secret key is derived from the signing key.
#
%(macaroon_secret_key)s
# Used to enable access token expiration.
#
expire_access_token: False
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values
form_secret: %(form_secret)s
# falsification of values. Must be specified for the User Consent
# forms to work.
#
%(form_secret)s
## Signing Keys ##
# Path to the signing key to sign messages with
#
signing_key_path: "%(base_key_name)s.signing.key"
# The keys that the server used to sign messages with but won't use
# to sign new messages. E.g. it has lost its private key
old_signing_keys: {}
#
#old_signing_keys:
# "ed25519:auto":
# # Base64 encoded public key
# key: "The public part of your old signing key."
@ -105,9 +116,11 @@ class KeyConfig(Config):
# Used to set the valid_until_ts in /key/v2 APIs.
# Determines how quickly servers will query to check which keys
# are still valid.
#
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
#
perspectives:
servers:
"matrix.org":

View file

@ -15,7 +15,6 @@
import logging
import logging.config
import os
import signal
import sys
from string import Template
@ -24,6 +23,7 @@ import yaml
from twisted.logger import STDLibLogObserver, globalLogBeginner
import synapse
from synapse.app import _base as appbase
from synapse.util.logcontext import LoggingContextFilter
from synapse.util.versionstring import get_version_string
@ -50,6 +50,7 @@ handlers:
maxBytes: 104857600
backupCount: 10
filters: [context]
encoding: utf8
console:
class: logging.StreamHandler
formatter: precise
@ -79,11 +80,10 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return """
# A yaml python logging config file
#
log_config: "%(log_config)s"
""" % locals()
@ -137,6 +137,9 @@ def setup_logging(config, use_worker_options=False):
use_worker_options (bool): True to use 'worker_log_config' and
'worker_log_file' options instead of 'log_config' and 'log_file'.
register_sighup (func | None): Function to call to register a
sighup handler.
"""
log_config = (config.worker_log_config if use_worker_options
else config.log_config)
@ -179,7 +182,7 @@ def setup_logging(config, use_worker_options=False):
else:
handler = logging.StreamHandler()
def sighup(signum, stack):
def sighup(*args):
pass
handler.setFormatter(formatter)
@ -192,20 +195,14 @@ def setup_logging(config, use_worker_options=False):
with open(log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
def sighup(signum, stack):
def sighup(*args):
# it might be better to use a file watcher or something for this.
load_log_config()
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
load_log_config()
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
appbase.register_sighup(sighup)
# make sure that the first thing we log is a thing we can grep backwards
# for
@ -246,3 +243,5 @@ def setup_logging(config, use_worker_options=False):
[_log],
redirectStandardIO=not config.no_redirect_stdio,
)
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")

View file

@ -13,7 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from ._base import Config, ConfigError
MISSING_SENTRY = (
"""Missing sentry-sdk library. This is required to enable sentry
integration.
"""
)
class MetricsConfig(Config):
@ -23,11 +29,43 @@ class MetricsConfig(Config):
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
self.sentry_enabled = "sentry" in config
if self.sentry_enabled:
try:
import sentry_sdk # noqa F401
except ImportError:
raise ConfigError(MISSING_SENTRY)
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
raise ConfigError(
"sentry.dsn field is required when sentry integration is enabled",
)
def default_config(self, report_stats=None, **kwargs):
suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n"
return ("""\
res = """\
## Metrics ###
# Enable collection and rendering of performance metrics
#
enable_metrics: False
""" + suffix) % locals()
# Enable sentry integration
# NOTE: While attempts are made to ensure that the logs don't contain
# any sensitive information, this cannot be guaranteed. By enabling
# this option the sentry server may therefore receive sensitive
# information, and it in turn may then diseminate sensitive information
# through insecure notification channels if so configured.
#
#sentry:
# dsn: "..."
# Whether or not to report anonymized homeserver usage statistics.
"""
if report_stats is None:
res += "# report_stats: true|false\n"
else:
res += "report_stats: %s\n" % ('true' if report_stats else 'false')
return res

View file

@ -28,6 +28,7 @@ class PasswordConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable password for login.
#
password_config:
enabled: true
# Uncomment and change to a secret random string for extra security.

View file

@ -52,18 +52,18 @@ class PasswordAuthProviderConfig(Config):
def default_config(self, **kwargs):
return """\
# password_providers:
# - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
#password_providers:
# - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
"""

View file

@ -51,11 +51,11 @@ class PushConfig(Config):
# notification request includes the content of the event (other details
# like the sender are still included). For `event_id_only` push, it
# has no effect.
#
# For modern android devices the notification content will still appear
# because it is loaded by the app. iPhone, however will send a
# notification saying only that a message arrived and who it came from.
#
#push:
# include_content: true
# include_content: true
"""

View file

@ -32,27 +32,34 @@ class RatelimitConfig(Config):
## Ratelimiting ##
# Number of messages a client can send per second
#
rc_messages_per_second: 0.2
# Number of message a client can send before being throttled
#
rc_message_burst_count: 10.0
# The federation window size in milliseconds
#
federation_rc_window_size: 1000
# The number of federation requests from a single server in a window
# before the server will delay processing the request.
#
federation_rc_sleep_limit: 10
# The duration in milliseconds to delay processing events from
# remote servers by if they go over the sleep limit.
#
federation_rc_sleep_delay: 500
# The maximum number of concurrent federation requests allowed
# from a single server
#
federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
#
federation_rc_concurrent: 3
"""

View file

@ -37,6 +37,7 @@ class RegistrationConfig(Config):
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
self.invite_3pid_guest = (
@ -49,8 +50,17 @@ class RegistrationConfig(Config):
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
self.disable_msisdn_registration = (
config.get("disable_msisdn_registration", False)
)
def default_config(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50),
)
else:
registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>'
return """\
## Registration ##
@ -60,54 +70,75 @@ class RegistrationConfig(Config):
# The user must provide all of the below types of 3PID when registering.
#
# registrations_require_3pid:
# - email
# - msisdn
#registrations_require_3pid:
# - email
# - msisdn
# Explicitly disable asking for MSISDNs from the registration
# flow (overrides registrations_require_3pid if MSISDNs are set as required)
#
#disable_msisdn_registration: True
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
# allowed_local_3pids:
# - medium: email
# pattern: ".*@matrix\\.org"
# - medium: email
# pattern: ".*@vector\\.im"
# - medium: msisdn
# pattern: "\\+44"
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
# - medium: email
# pattern: '.*@vector\\.im'
# - medium: msisdn
# pattern: '\\+44'
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
#
%(registration_shared_secret)s
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number is 12 (which equates to 2^12 rounds).
# N.B. that increasing this will exponentially increase the time required
# to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
#
bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
#
allow_guest_access: False
# The identity server which we suggest that clients should use when users log
# in on this server.
#
# (By default, no suggestion is made, so it is left up to the client.
# This setting is ignored unless public_baseurl is also set.)
#
#default_identity_server: https://matrix.org
# The list of identity servers trusted to verify third party
# identifiers by this server.
#
# Also defines the ID server which will be called when an account is
# deactivated (one will be picked arbitrarily).
#
trusted_third_party_id_servers:
- matrix.org
- vector.im
- riot.im
- matrix.org
- vector.im
# Users who register on this homeserver will automatically be joined
# to these rooms
#
#auto_join_rooms:
# - "#example:example.com"
# - "#example:example.com"
# Where auto_join_rooms are specified, setting this flag ensures that the
# the rooms exist by creating them when the first user on the
# homeserver registers.
# Setting to false means that if the rooms are not manually created,
# users cannot be auto-joined since they do not exist.
#
autocreate_auto_join_rooms: true
""" % locals()

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import namedtuple
from synapse.util.module_loader import load_module
@ -175,34 +175,39 @@ class ContentRepositoryConfig(Config):
"url_preview_url_blacklist", ()
)
def default_config(self, **kwargs):
media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads")
def default_config(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store")
uploads_path = os.path.join(data_dir_path, "uploads")
return r"""
# Directory where uploaded images and attachments are stored.
#
media_store_path: "%(media_store)s"
# Media storage providers allow media to be stored in different
# locations.
# media_storage_providers:
# - module: file_system
# # Whether to write new local files.
# store_local: false
# # Whether to write new remote media
# store_remote: false
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory
#
#media_storage_providers:
# - module: file_system
# # Whether to write new local files.
# store_local: false
# # Whether to write new remote media
# store_remote: false
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory
# Directory where in-progress uploads are stored.
#
uploads_path: "%(uploads_path)s"
# The largest allowed upload size in bytes
#
max_upload_size: "10M"
# Maximum number of pixels that will be thumbnailed
#
max_image_pixels: "32M"
# Whether to generate new thumbnails on the fly to precisely match
@ -210,9 +215,11 @@ class ContentRepositoryConfig(Config):
# a new resolution is requested by the client the server will
# generate a new thumbnail. If false the server will pick a thumbnail
# from a precalculated list.
#
dynamic_thumbnails: false
# List of thumbnail to precalculate when an image is uploaded.
# List of thumbnails to precalculate when an image is uploaded.
#
thumbnail_sizes:
- width: 32
height: 32
@ -233,6 +240,7 @@ class ContentRepositoryConfig(Config):
# Is the preview URL API enabled? If enabled, you *must* specify
# an explicit url_preview_ip_range_blacklist of IPs that the spider is
# denied from accessing.
#
url_preview_enabled: False
# List of IP address CIDR ranges that the URL preview spider is denied
@ -243,16 +251,16 @@ class ContentRepositoryConfig(Config):
# synapse to issue arbitrary GET requests to your internal services,
# causing serious security issues.
#
# url_preview_ip_range_blacklist:
# - '127.0.0.0/8'
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '169.254.0.0/16'
# - '::1/128'
# - 'fe80::/64'
# - 'fc00::/7'
#url_preview_ip_range_blacklist:
# - '127.0.0.0/8'
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '169.254.0.0/16'
# - '::1/128'
# - 'fe80::/64'
# - 'fc00::/7'
#
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
@ -260,8 +268,8 @@ class ContentRepositoryConfig(Config):
# target IP ranges - e.g. for enabling URL previews for a specific private
# website only visible in your network.
#
# url_preview_ip_range_whitelist:
# - '192.168.1.1'
#url_preview_ip_range_whitelist:
# - '192.168.1.1'
# Optional list of URL matches that the URL preview spider is
# denied from accessing. You should use url_preview_ip_range_blacklist
@ -279,26 +287,25 @@ class ContentRepositoryConfig(Config):
# specified component matches for a given list item succeed, the URL is
# blacklisted.
#
# url_preview_url_blacklist:
# # blacklist any URL with a username in its URI
# - username: '*'
#url_preview_url_blacklist:
# # blacklist any URL with a username in its URI
# - username: '*'
#
# # blacklist all *.google.com URLs
# - netloc: 'google.com'
# - netloc: '*.google.com'
# # blacklist all *.google.com URLs
# - netloc: 'google.com'
# - netloc: '*.google.com'
#
# # blacklist all plain HTTP URLs
# - scheme: 'http'
# # blacklist all plain HTTP URLs
# - scheme: 'http'
#
# # blacklist http(s)://www.acme.com/foo
# - netloc: 'www.acme.com'
# path: '/foo'
# # blacklist http(s)://www.acme.com/foo
# - netloc: 'www.acme.com'
# path: '/foo'
#
# # blacklist any URL with a literal IPv4 address
# - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'
# # blacklist any URL with a literal IPv4 address
# - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'
# The largest allowed URL preview spidering size in bytes
max_spider_size: "10M"
""" % locals()

View file

@ -20,12 +20,37 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
def read_config(self, config):
alias_creation_rules = config["alias_creation_rules"]
alias_creation_rules = config.get("alias_creation_rules")
self._alias_creation_rules = [
_AliasRule(rule)
for rule in alias_creation_rules
]
if alias_creation_rules is not None:
self._alias_creation_rules = [
_RoomDirectoryRule("alias_creation_rules", rule)
for rule in alias_creation_rules
]
else:
self._alias_creation_rules = [
_RoomDirectoryRule(
"alias_creation_rules", {
"action": "allow",
}
)
]
room_list_publication_rules = config.get("room_list_publication_rules")
if room_list_publication_rules is not None:
self._room_list_publication_rules = [
_RoomDirectoryRule("room_list_publication_rules", rule)
for rule in room_list_publication_rules
]
else:
self._room_list_publication_rules = [
_RoomDirectoryRule(
"room_list_publication_rules", {
"action": "allow",
}
)
]
self.allow_non_federated_in_public_rooms = config.get(
"allow_non_federated_in_public_rooms", True,
@ -37,67 +62,146 @@ class RoomDirectoryConfig(Config):
# on this server.
#
# The format of this option is a list of rules that contain globs that
# match against user_id and the new alias (fully qualified with server
# name). The action in the first rule that matches is taken, which can
# currently either be "allow" or "deny".
# match against user_id, room_id and the new alias (fully qualified with
# server name). The action in the first rule that matches is taken,
# which can currently either be "allow" or "deny".
#
# If no rules match the request is denied.
alias_creation_rules:
- user_id: "*"
alias: "*"
action: allow
# Missing user_id/room_id/alias fields default to "*".
#
# If no rules match the request is denied. An empty list means no one
# can create aliases.
#
# Options for the rules include:
#
# user_id: Matches against the creator of the alias
# alias: Matches against the alias being created
# room_id: Matches against the room ID the alias is being pointed at
# action: Whether to "allow" or "deny" the request if the rule matches
#
# The default is:
#
#alias_creation_rules:
# - user_id: "*"
# alias: "*"
# room_id: "*"
# action: allow
# The `room_list_publication_rules` option controls who can publish and
# which rooms can be published in the public room list.
#
# The format of this option is the same as that for
# `alias_creation_rules`.
#
# If the room has one or more aliases associated with it, only one of
# the aliases needs to match the alias rule. If there are no aliases
# then only rules with `alias: *` match.
#
# If no rules match the request is denied. An empty list means no one
# can publish rooms.
#
# Options for the rules include:
#
# user_id: Matches agaisnt the creator of the alias
# room_id: Matches against the room ID being published
# alias: Matches against any current local or canonical aliases
# associated with the room
# action: Whether to "allow" or "deny" the request if the rule matches
#
# The default is:
#
#room_list_publication_rules:
# - user_id: "*"
# alias: "*"
# room_id: "*"
# action: allow
# Specify whether rooms that only allow local users to join should be
# shown in the federation public room directory.
#
#
# Note that this does not affect the room directory shown to users on
# this homeserver, only those on other homeservers.
#
#allow_non_federated_in_public_rooms: True
"""
def is_alias_creation_allowed(self, user_id, alias):
def is_alias_creation_allowed(self, user_id, room_id, alias):
"""Checks if the given user is allowed to create the given alias
Args:
user_id (str)
room_id (str)
alias (str)
Returns:
boolean: True if user is allowed to crate the alias
"""
for rule in self._alias_creation_rules:
if rule.matches(user_id, alias):
if rule.matches(user_id, room_id, [alias]):
return rule.action == "allow"
return False
def is_publishing_room_allowed(self, user_id, room_id, aliases):
"""Checks if the given user is allowed to publish the room
Args:
user_id (str)
room_id (str)
aliases (list[str]): any local aliases associated with the room
Returns:
boolean: True if user can publish room
"""
for rule in self._room_list_publication_rules:
if rule.matches(user_id, room_id, aliases):
return rule.action == "allow"
return False
class _AliasRule(object):
def __init__(self, rule):
class _RoomDirectoryRule(object):
"""Helper class to test whether a room directory action is allowed, like
creating an alias or publishing a room.
"""
def __init__(self, option_name, rule):
"""
Args:
option_name (str): Name of the config option this rule belongs to
rule (dict): The rule as specified in the config
"""
action = rule["action"]
user_id = rule["user_id"]
alias = rule["alias"]
user_id = rule.get("user_id", "*")
room_id = rule.get("room_id", "*")
alias = rule.get("alias", "*")
if action in ("allow", "deny"):
self.action = action
else:
raise ConfigError(
"alias_creation_rules rules can only have action of 'allow'"
" or 'deny'"
"%s rules can only have action of 'allow'"
" or 'deny'" % (option_name,)
)
self._alias_matches_all = alias == "*"
try:
self._user_id_regex = glob_to_regex(user_id)
self._alias_regex = glob_to_regex(alias)
self._room_id_regex = glob_to_regex(room_id)
except Exception as e:
raise ConfigError("Failed to parse glob into regex: %s", e)
def matches(self, user_id, alias):
"""Tests if this rule matches the given user_id and alias.
def matches(self, user_id, room_id, aliases):
"""Tests if this rule matches the given user_id, room_id and aliases.
Args:
user_id (str)
alias (str)
room_id (str)
aliases (list[str]): The associated aliases to the room. Will be a
single element for testing alias creation, and can be empty for
testing room publishing.
Returns:
boolean
@ -107,7 +211,22 @@ class _AliasRule(object):
if not self._user_id_regex.match(user_id):
return False
if not self._alias_regex.match(alias):
if not self._room_id_regex.match(room_id):
return False
return True
# We only have alias checks left, so we can short circuit if the alias
# rule matches everything.
if self._alias_matches_all:
return True
# If we are not given any aliases then this rule only matches if the
# alias glob matches all aliases, which we checked above.
if not aliases:
return False
# Otherwise, we just need one alias to match
for alias in aliases:
if self._alias_regex.match(alias):
return True
return False

View file

@ -1,55 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Ericsson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class SAML2Config(Config):
"""SAML2 Configuration
Synapse uses pysaml2 libraries for providing SAML2 support
config_path: Path to the sp_conf.py configuration file
idp_redirect_url: Identity provider URL which will redirect
the user back to /login/saml2 with proper info.
sp_conf.py file is something like:
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
More information: https://pythonhosted.org/pysaml2/howto/config.html
"""
def read_config(self, config):
saml2_config = config.get("saml2_config", None)
if saml2_config:
self.saml2_enabled = saml2_config.get("enabled", True)
self.saml2_config_path = saml2_config["config_path"]
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
else:
self.saml2_enabled = False
self.saml2_config_path = None
self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file
# idp_redirect_url: Identity provider URL which will redirect
# the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config.
#saml2_config:
# enabled: true
# config_path: "%s/sp_conf.py"
# idp_redirect_url: "http://%s/idp"
""" % (config_dir_path, server_name)

View file

@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
class SAML2Config(Config):
def read_config(self, config):
self.saml2_enabled = False
saml2_config = config.get("saml2_config")
if not saml2_config or not saml2_config.get("enabled", True):
return
self.saml2_enabled = True
import saml2.config
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(self._default_saml_config_dict())
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
config_path = saml2_config.get("config_path", None)
if config_path is not None:
self.saml2_sp_config.load_file(config_path)
def _default_saml_config_dict(self):
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError(
"saml2_config requires a public_baseurl to be set"
)
metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
response_url = public_baseurl + "_matrix/saml2/authn_response"
return {
"entityid": metadata_url,
"service": {
"sp": {
"endpoints": {
"assertion_consumer_service": [
(response_url, saml2.BINDING_HTTP_POST),
],
},
"required_attributes": ["uid"],
"optional_attributes": ["mail", "surname", "givenname"],
},
}
}
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable SAML2 for registration and login. Uses pysaml2.
#
# `sp_config` is the configuration for the pysaml2 Service Provider.
# See pysaml2 docs for format of config.
#
# Default values will be used for the 'entityid' and 'service' settings,
# so it is not normally necessary to specify them unless you need to
# override them.
#
#saml2_config:
# sp_config:
# # point this to the IdP's metadata. You can use either a local file or
# # (preferably) a URL.
# metadata:
# #local: ["saml2/idp.xml"]
# remote:
# - url: https://our_idp/metadata.xml
#
# # The rest of sp_config is just used to generate our metadata xml, and you
# # may well not need it, depending on your setup. Alternatively you
# # may need a whole lot more detail - see the pysaml2 docs!
#
# description: ["My awesome SP", "en"]
# name: ["Test SP", "en"]
#
# organization:
# name: Example com
# display_name:
# - ["Example co", "en"]
# url: "http://example.com"
#
# contact_person:
# - given_name: Bob
# sur_name: "the Sysadmin"
# email_address": ["admin@example.com"]
# contact_type": technical
#
# # Instead of putting the config inline as above, you can specify a
# # separate pysaml2 configuration file:
# #
# config_path: "%(config_dir_path)s/sp_conf.py"
""" % {"config_dir_path": config_dir_path}

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# Copyright 2017-2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,13 +15,23 @@
# limitations under the License.
import logging
import os.path
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
logger = logging.Logger(__name__)
# by default, we attempt to listen on both '::' *and* '0.0.0.0' because some OSes
# (Windows, macOS, other BSD/Linux where net.ipv6.bindv6only is set) will only listen
# on IPv6 when '::' is set.
#
# We later check for errors when binding to 0.0.0.0 and ignore them if :: is also in
# in the list.
DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
class ServerConfig(Config):
@ -34,7 +44,6 @@ class ServerConfig(Config):
raise ConfigError(str(e))
self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None)
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
@ -62,6 +71,11 @@ class ServerConfig(Config):
# master, potentially causing inconsistency.
self.enable_media_repo = config.get("enable_media_repo", True)
# whether to enable search. If disabled, new entries will not be inserted
# into the search tables and they will not be indexed. Users will receive
# errors when attempting to search for messages.
self.enable_search = config.get("enable_search", True)
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
# Whether we should block invites sent to users on this server
@ -77,6 +91,7 @@ class ServerConfig(Config):
self.max_mau_value = config.get(
"max_mau_value", 0,
)
self.mau_stats_only = config.get("mau_stats_only", False)
self.mau_limits_reserved_threepids = config.get(
"mau_limit_reserved_threepids", []
@ -111,27 +126,53 @@ class ServerConfig(Config):
self.public_baseurl += '/'
self.start_pushers = config.get("start_pushers", True)
self.listeners = config.get("listeners", [])
self.listeners = []
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
raise ConfigError(
"Listener configuration is lacking a valid 'port' option"
)
if listener.setdefault("tls", False):
# no_tls is not really supported any more, but let's grandfather it in
# here.
if config.get("no_tls", False):
logger.info(
"Ignoring TLS-enabled listener on port %i due to no_tls"
)
continue
for listener in self.listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
# if bind_address was specified, add it to the list of addresses
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')
# if we still have an empty list of addresses, use the default list
if not bind_addresses:
if listener['type'] == 'metrics':
# the metrics listener doesn't support IPv6
bind_addresses.append('0.0.0.0')
else:
bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
self.listeners.append(listener)
if not self.web_client_location:
_warn_if_webclient_configured(self.listeners)
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
bind_port = config.get("bind_port")
if bind_port:
if config.get("no_tls", False):
raise ConfigError("no_tls is incompatible with bind_port")
self.listeners = []
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
names = ["client", "webclient"] if self.web_client else ["client"]
self.listeners.append({
"port": bind_port,
"bind_addresses": [bind_host],
@ -139,7 +180,7 @@ class ServerConfig(Config):
"type": "http",
"resources": [
{
"names": names,
"names": ["client"],
"compress": gzip_responses,
},
{
@ -158,7 +199,7 @@ class ServerConfig(Config):
"type": "http",
"resources": [
{
"names": names,
"names": ["client"],
"compress": gzip_responses,
},
{
@ -174,6 +215,7 @@ class ServerConfig(Config):
"port": manhole,
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
"tls": False,
})
metrics_port = config.get("metrics_port")
@ -197,7 +239,12 @@ class ServerConfig(Config):
]
})
def default_config(self, server_name, **kwargs):
_check_resource_config(self.listeners)
def has_tls_listener(self):
return any(l["tls"] for l in self.listeners)
def default_config(self, server_name, data_dir_path, **kwargs):
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@ -205,7 +252,7 @@ class ServerConfig(Config):
bind_port = 8448
unsecure_port = 8008
pid_file = self.abspath("homeserver.pid")
pid_file = os.path.join(data_dir_path, "homeserver.pid")
return """\
## Server ##
@ -239,19 +286,20 @@ class ServerConfig(Config):
#
# This setting requires the affinity package to be installed!
#
# cpu_affinity: 0xFFFFFFFF
#cpu_affinity: 0xFFFFFFFF
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
# The path to the web client which will be served at /_matrix/client/
# if 'webclient' is configured under the 'listeners' configuration.
#
#web_client_location: "/path/to/web/root"
# The root directory to server for the above web client.
# If left undefined, synapse will serve the matrix-angular-sdk web client.
# Make sure matrix-angular-sdk is installed with pip if web_client is True
# and web_client_location is undefined
# web_client_location: "/path/to/web/root"
# The public-facing base URL for the client API (not including _matrix/...)
# public_baseurl: https://example.com:8448/
# The public-facing base URL that clients use to access this HS
# (not including _matrix/...). This is the same URL a user would
# enter into the 'custom HS URL' field on their client. If you
# use synapse with a reverse proxy, this should be the URL to reach
# synapse via the proxy.
#
#public_baseurl: https://example.com/
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
@ -262,15 +310,25 @@ class ServerConfig(Config):
use_presence: true
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10]
#
#gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is -1, means no upper limit.
# filter_timeline_limit: 5000
#
#filter_timeline_limit: 5000
# Whether room invites to users on this server should be blocked
# (except those sent by local server admins). The default is False.
# block_non_admin_invites: True
#
#block_non_admin_invites: True
# Room searching
#
# If disabled, new messages will not be indexed for searching and users
# will receive errors when searching for messages. Defaults to enabled.
#
#enable_search: false
# Restrict federation to the following whitelist of domains.
# N.B. we recommend also firewalling your federation listener to limit
@ -278,107 +336,145 @@ class ServerConfig(Config):
# purely on this application-layer restriction. If not specified, the
# default is to whitelist everything.
#
# federation_domain_whitelist:
#federation_domain_whitelist:
# - lon.example.com
# - nyc.example.com
# - syd.example.com
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
# Options for each listener include:
#
# port: the TCP port to bind to
#
# bind_addresses: a list of local addresses to listen on. The default is
# 'all local interfaces'.
#
# type: the type of listener. Normally 'http', but other valid options are:
# 'manhole' (see docs/manhole.md),
# 'metrics' (see docs/metrics-howto.rst),
# 'replication' (see docs/workers.rst).
#
# tls: set to true to enable TLS for this listener. Will use the TLS
# key/cert specified in tls_private_key_path / tls_certificate_path.
#
# x_forwarded: Only valid for an 'http' listener. Set to true to use the
# X-Forwarded-For header as the client IP. Useful when Synapse is
# behind a reverse-proxy.
#
# resources: Only valid for an 'http' listener. A list of resources to host
# on this port. Options for each resource are:
#
# names: a list of names of HTTP resources. See below for a list of
# valid resource names.
#
# compress: set to true to enable HTTP comression for this resource.
#
# additional_resources: Only valid for an 'http' listener. A map of
# additional endpoints which should be loaded via dynamic modules.
#
# Valid resource names are:
#
# client: the client-server API (/_matrix/client). Also implies 'media' and
# 'static'.
#
# consent: user consent forms (/_matrix/consent). See
# docs/consent_tracking.md.
#
# federation: the server-server API (/_matrix/federation). Also implies
# 'media', 'keys', 'openid'
#
# keys: the key discovery API (/_matrix/keys).
#
# media: the media API (/_matrix/media).
#
# metrics: the metrics interface. See docs/metrics-howto.rst.
#
# openid: OpenID authentication.
#
# replication: the HTTP replication API (/_synapse/replication). See
# docs/workers.rst.
#
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
#
# webclient: A web client. Requires web_client_location to be set.
#
listeners:
# Main HTTPS listener
# For when matrix traffic is sent directly to synapse.
-
# The port to listen for HTTPS requests on.
port: %(bind_port)s
# TLS-enabled listener: for when matrix traffic is sent directly to synapse.
#
# Disabled by default. To enable it, uncomment the following. (Note that you
# will also need to give Synapse a TLS key and certificate: see the TLS section
# below.)
#
#- port: %(bind_port)s
# type: http
# tls: true
# resources:
# - names: [client, federation]
# Local addresses to listen on.
# On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
# addresses by default. For most other OSes, this will only listen
# on IPv6.
bind_addresses:
- '::'
- '0.0.0.0'
# This is a 'http' listener, allows us to specify 'resources'.
type: http
tls: true
# Use the X-Forwarded-For (XFF) header as the client IP and not the
# actual client IP.
x_forwarded: false
# List of HTTP resources to serve on this listener.
resources:
-
# List of resources to host on this listener.
names:
- client # The client-server APIs, both v1 and v2
- webclient # The bundled webclient.
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
compress: true
- names: [federation] # Federation APIs
compress: false
# optional list of additional endpoints which can be loaded via
# dynamic modules
# additional_resources:
# "/_matrix/my/custom/endpoint":
# module: my_module.CustomRequestHandler
# config: {}
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
# Unsecure HTTP listener: for when matrix traffic passes through a reverse proxy
# that unwraps TLS.
#
# If you plan to use a reverse proxy, please see
# https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
#
- port: %(unsecure_port)s
tls: false
bind_addresses: ['::', '0.0.0.0']
bind_addresses: ['::1', '127.0.0.1']
type: http
x_forwarded: false
x_forwarded: true
resources:
- names: [client, webclient]
compress: true
- names: [federation]
- names: [client, federation]
compress: false
# example additonal_resources:
#
#additional_resources:
# "/_matrix/my/custom/endpoint":
# module: my_module.CustomRequestHandler
# config: {}
# Turn on the twisted ssh manhole service on localhost on the given
# port.
# - port: 9000
# bind_addresses: ['::1', '127.0.0.1']
# type: manhole
#
#- port: 9000
# bind_addresses: ['::1', '127.0.0.1']
# type: manhole
# Homeserver blocking
#
# How to reach the server admin, used in ResourceLimitError
# admin_contact: 'mailto:admin@server.com'
#
# Global block config
#
# hs_disabled: False
# hs_disabled_message: 'Human readable reason for why the HS is blocked'
# hs_disabled_limit_type: 'error code(str), to help clients decode reason'
#
# Monthly Active User Blocking
#
# Enables monthly active user checking
# limit_usage_by_mau: False
# max_mau_value: 50
# mau_trial_days: 2
#
# Sometimes the server admin will want to ensure certain accounts are
# never blocked by mau checking. These accounts are specified here.
#
# mau_limit_reserved_threepids:
# - medium: 'email'
# address: 'reserved_user@example.com'
## Homeserver blocking ##
# How to reach the server admin, used in ResourceLimitError
#
#admin_contact: 'mailto:admin@server.com'
# Global blocking
#
#hs_disabled: False
#hs_disabled_message: 'Human readable reason for why the HS is blocked'
#hs_disabled_limit_type: 'error code(str), to help clients decode reason'
# Monthly Active User Blocking
#
#limit_usage_by_mau: False
#max_mau_value: 50
#mau_trial_days: 2
# If enabled, the metrics for the number of monthly active users will
# be populated, however no one will be limited. If limit_usage_by_mau
# is true, this is implied to be true.
#
#mau_stats_only: False
# Sometimes the server admin will want to ensure certain accounts are
# never blocked by mau checking. These accounts are specified here.
#
#mau_limit_reserved_threepids:
# - medium: 'email'
# address: 'reserved_user@example.com'
""" % locals()
def read_arguments(self, args):
@ -404,19 +500,18 @@ class ServerConfig(Config):
" service on the given port.")
def is_threepid_reserved(config, threepid):
def is_threepid_reserved(reserved_threepids, threepid):
"""Check the threepid against the reserved threepid config
Args:
config(ServerConfig) - to access server config attributes
reserved_threepids([dict]) - list of reserved threepids
threepid(dict) - The threepid to test for
Returns:
boolean Is the threepid undertest reserved_user
"""
for tp in config.mau_limits_reserved_threepids:
if (threepid['medium'] == tp['medium']
and threepid['address'] == tp['address']):
for tp in reserved_threepids:
if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
return True
return False
@ -436,3 +531,53 @@ def read_gc_thresholds(thresholds):
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
)
NO_MORE_WEB_CLIENT_WARNING = """
Synapse no longer includes a web client. To enable a web client, configure
web_client_location. To remove this warning, remove 'webclient' from the 'listeners'
configuration.
"""
def _warn_if_webclient_configured(listeners):
for listener in listeners:
for res in listener.get("resources", []):
for name in res.get("names", []):
if name == 'webclient':
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
KNOWN_RESOURCES = (
'client',
'consent',
'federation',
'keys',
'media',
'metrics',
'openid',
'replication',
'static',
'webclient',
)
def _check_resource_config(listeners):
resource_names = set(
res_name
for listener in listeners
for res in listener.get("resources", [])
for res_name in res.get("names", [])
)
for resource in resource_names:
if resource not in KNOWN_RESOURCES:
raise ConfigError(
"Unknown listener resource '%s'" % (resource, )
)
if resource == "consent":
try:
check_requirements('resources.consent')
except DependencyException as e:
raise ConfigError(e.message)

View file

@ -30,11 +30,11 @@ DEFAULT_CONFIG = """\
# It's also possible to override the room name, the display name of the
# "notices" user, and the avatar for the user.
#
# server_notices:
# system_mxid_localpart: notices
# system_mxid_display_name: "Server Notices"
# system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
# room_name: "Server Notices"
#server_notices:
# system_mxid_localpart: notices
# system_mxid_display_name: "Server Notices"
# system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
# room_name: "Server Notices"
"""

View file

@ -28,8 +28,8 @@ class SpamCheckerConfig(Config):
def default_config(self, **kwargs):
return """\
# spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
#spam_checker:
# module: "my_custom_project.SuperSpamChecker"
# config:
# example_option: 'things'
"""

View file

@ -13,51 +13,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import subprocess
import warnings
from datetime import datetime
from hashlib import sha256
from unpaddedbase64 import encode_base64
from OpenSSL import crypto
from ._base import Config
from synapse.config._base import Config, ConfigError
GENERATE_DH_PARAMS = False
logger = logging.getLogger(__name__)
class TlsConfig(Config):
def read_config(self, config):
self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path")
acme_config = config.get("acme", None)
if acme_config is None:
acme_config = {}
self.acme_enabled = acme_config.get("enabled", False)
self.acme_url = acme_config.get(
"url", u"https://acme-v01.api.letsencrypt.org/directory"
)
self.tls_certificate_file = config.get("tls_certificate_path")
self.acme_port = acme_config.get("port", 80)
self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0'])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.acme_domain = acme_config.get("domain", config.get("server_name"))
self.no_tls = config.get("no_tls", False)
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
config.get("tls_private_key_path")
)
if self.has_tls_listener():
if not self.tls_certificate_file:
raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are "
"configured."
)
if not self.tls_private_key_file:
raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are "
"configured."
)
self.tls_dh_params_path = self.check_file(
config.get("tls_dh_params_path"), "tls_dh_params"
)
self._original_tls_fingerprints = config.get("tls_fingerprints", [])
self.tls_fingerprints = config["tls_fingerprints"]
if self._original_tls_fingerprints is None:
self._original_tls_fingerprints = []
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
self.tls_fingerprints = list(self._original_tls_fingerprints)
# This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such)
@ -67,29 +74,176 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
self.tls_certificate = None
self.tls_private_key = None
def is_disk_cert_valid(self, allow_self_signed=True):
"""
Is the certificate we have on disk valid, and if so, for how long?
Args:
allow_self_signed (bool): Should we allow the certificate we
read to be self signed?
Returns:
int: Days remaining of certificate validity.
None: No certificate exists.
"""
if not os.path.exists(self.tls_certificate_file):
return None
try:
with open(self.tls_certificate_file, 'rb') as f:
cert_pem = f.read()
except Exception:
logger.exception("Failed to read existing certificate off disk!")
raise
try:
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
except Exception:
logger.exception("Failed to parse existing certificate off disk!")
raise
if not allow_self_signed:
if tls_certificate.get_subject() == tls_certificate.get_issuer():
raise ValueError(
"TLS Certificate is self signed, and this is not permitted"
)
# YYYYMMDDhhmmssZ -- in UTC
expires_on = datetime.strptime(
tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
)
now = datetime.utcnow()
days_remaining = (expires_on - now).days
return days_remaining
def read_certificate_from_disk(self, require_cert_and_key):
"""
Read the certificates and private key from disk.
Args:
require_cert_and_key (bool): set to True to throw an error if the certificate
and key file are not given
"""
if require_cert_and_key:
self.tls_private_key = self.read_tls_private_key()
self.tls_certificate = self.read_tls_certificate()
elif self.tls_certificate_file:
# we only need the certificate for the tls_fingerprints. Reload it if we
# can, but it's not a fatal error if we can't.
try:
self.tls_certificate = self.read_tls_certificate()
except Exception as e:
logger.info(
"Unable to read TLS certificate (%s). Ignoring as no "
"tls listeners enabled.", e,
)
self.tls_fingerprints = list(self._original_tls_fingerprints)
if self.tls_certificate:
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
tls_dh_params_path = base_key_name + ".tls.dh"
return """\
# PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
# if you like. Any required intermediary certificates can be
# appended after the primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s"
# this is to avoid the max line length. Sorrynotsorry
proxypassline = (
'ProxyPass /.well-known/acme-challenge '
'http://localhost:8009/.well-known/acme-challenge'
)
# PEM encoded private key for TLS
tls_private_key_path: "%(tls_private_key_path)s"
return (
"""\
## TLS ##
# PEM dh parameters for ephemeral keys
tls_dh_params_path: "%(tls_dh_params_path)s"
# PEM-encoded X509 certificate for TLS.
# This certificate, as of Synapse 1.0, will need to be a valid and verifiable
# certificate, signed by a recognised Certificate Authority.
#
# See 'ACME support' below to enable auto-provisioning this certificate via
# Let's Encrypt.
#
#tls_certificate_path: "%(tls_certificate_path)s"
# Don't bind to the https port
no_tls: False
# PEM-encoded private key for TLS
#
#tls_private_key_path: "%(tls_private_key_path)s"
# ACME support: This will configure Synapse to request a valid TLS certificate
# for your configured `server_name` via Let's Encrypt.
#
# Note that provisioning a certificate in this way requires port 80 to be
# routed to Synapse so that it can complete the http-01 ACME challenge.
# By default, if you enable ACME support, Synapse will attempt to listen on
# port 80 for incoming http-01 challenges - however, this will likely fail
# with 'Permission denied' or a similar error.
#
# There are a couple of potential solutions to this:
#
# * If you already have an Apache, Nginx, or similar listening on port 80,
# you can configure Synapse to use an alternate port, and have your web
# server forward the requests. For example, assuming you set 'port: 8009'
# below, on Apache, you would write:
#
# %(proxypassline)s
#
# * Alternatively, you can use something like `authbind` to give Synapse
# permission to listen on port 80.
#
acme:
# ACME support is disabled by default. Uncomment the following line
# (and tls_certificate_path and tls_private_key_path above) to enable it.
#
#enabled: true
# Endpoint to use to request certificates. If you only want to test,
# use Let's Encrypt's staging url:
# https://acme-staging.api.letsencrypt.org/directory
#
#url: https://acme-v01.api.letsencrypt.org/directory
# Port number to listen on for the HTTP-01 challenge. Change this if
# you are forwarding connections through Apache/Nginx/etc.
#
#port: 80
# Local addresses to listen on for incoming connections.
# Again, you may want to change this if you are forwarding connections
# through Apache/Nginx/etc.
#
#bind_addresses: ['::', '0.0.0.0']
# How many days remaining on a certificate before it is renewed.
#
#reprovision_threshold: 30
# The domain that the certificate should be for. Normally this
# should be the same as your Matrix domain (i.e., 'server_name'), but,
# by putting a file at 'https://<server_name>/.well-known/matrix/server',
# you can delegate incoming traffic to another server. If you do that,
# you should give the target of the delegation here.
#
# For example: if your 'server_name' is 'example.com', but
# 'https://example.com/.well-known/matrix/server' delegates to
# 'matrix.example.com', you should put 'matrix.example.com' here.
#
# If not set, defaults to your 'server_name'.
#
#domain: matrix.example.com
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
@ -116,80 +270,44 @@ class TlsConfig(Config):
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
# or by checking matrix.org/federationtester/api/report?server_name=$host
#
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals()
#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
"""
% locals()
)
def read_tls_private_key(self, private_key_path):
private_key_pem = self.read_file(private_key_path, "tls_private_key")
def read_tls_certificate(self):
"""Reads the TLS certificate from the configured file, and returns it
Also checks if it is self-signed, and warns if so
Returns:
OpenSSL.crypto.X509: the certificate
"""
cert_path = self.tls_certificate_file
logger.info("Loading TLS certificate from %s", cert_path)
cert_pem = self.read_file(cert_path, "tls_certificate_path")
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
# Check if it is self-signed, and issue a warning if so.
if cert.get_issuer() == cert.get_subject():
warnings.warn(
(
"Self-signed TLS certificates will not be accepted by Synapse 1.0. "
"Please either provide a valid certificate, or use Synapse's ACME "
"support to provision one."
)
)
return cert
def read_tls_private_key(self):
"""Reads the TLS private key from the configured file, and returns it
Returns:
OpenSSL.crypto.PKey: the private key
"""
private_key_path = self.tls_private_key_file
logger.info("Loading TLS key from %s", private_key_path)
private_key_pem = self.read_file(private_key_path, "tls_private_key_path")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
def generate_files(self, config):
tls_certificate_path = config["tls_certificate_path"]
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
if not self.path_exists(tls_private_key_path):
with open(tls_private_key_path, "wb") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
crypto.FILETYPE_PEM, tls_private_key
)
private_key_file.write(private_key_pem)
else:
with open(tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem
)
if not self.path_exists(tls_certificate_path):
with open(tls_certificate_path, "wb") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(tls_private_key)
cert.sign(tls_private_key, 'sha256')
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
certificate_file.write(cert_pem)
if not self.path_exists(tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
"-outform", "PEM",
"-out", tls_dh_params_path,
"2048"
])
else:
with open(tls_dh_params_path, "w") as dh_params_file:
dh_params_file.write(
"2048-bit DH parameters taken from rfc3526\n"
"-----BEGIN DH PARAMETERS-----\n"
"MIIBCAKCAQEA///////////JD9qiIWjC"
"NMTGYouA3BzRKQJOCIpnzHQCC76mOxOb\n"
"IlFKCHmONATd75UZs806QxswKwpt8l8U"
"N0/hNW1tUcJF5IW1dmJefsb0TELppjft\n"
"awv/XLb0Brft7jhr+1qJn6WunyQRfEsf"
"5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT\n"
"mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVS"
"u57VKQdwlpZtZww1Tkq8mATxdGwIyhgh\n"
"fDKQXkYuNs474553LBgOhgObJ4Oi7Aei"
"j7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq\n"
"5RXSJhiY+gUQFXKOWoqsqmj/////////"
"/wIBAg==\n"
"-----END DH PARAMETERS-----\n"
)

View file

@ -40,5 +40,5 @@ class UserDirectoryConfig(Config):
# on your database to tell it to rebuild the user_directory search indexes.
#
#user_directory:
# search_all_users: false
# search_all_users: false
"""

View file

@ -27,20 +27,24 @@ class VoipConfig(Config):
def default_config(self, **kwargs):
return """\
## Turn ##
## TURN ##
# The public URIs of the TURN server to give to clients
#
#turn_uris: []
# The shared secret used to compute passwords for the TURN server
#
#turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and
# does not use a token
#
#turn_username: "TURNSERVER_USERNAME"
#turn_password: "TURNSERVER_PASSWORD"
# How long generated TURN credentials last
#
turn_user_lifetime: "1h"
# Whether guests should be allowed to use the TURN server.
@ -48,5 +52,6 @@ class VoipConfig(Config):
# However, it does introduce a slight security risk as it allows users to
# connect to arbitrary endpoints without having first signed up for a
# valid account (e.g. by passing a CAPTCHA).
#
turn_allow_guests: True
"""

View file

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,12 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from zope.interface import implementer
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import _defaultCurveName
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure
@ -42,12 +45,12 @@ class ServerContextFactory(ContextFactory):
logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate_chain_file(config.tls_certificate_file)
context.use_privatekey(config.tls_private_key)
if not config.no_tls:
context.use_privatekey(config.tls_private_key)
context.load_tmp_dh(config.tls_dh_params_path)
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
context.set_cipher_list(
"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1"
)
def getContext(self):
return self._context
@ -96,11 +99,15 @@ class ClientTLSOptions(object):
def __init__(self, hostname, ctx):
self._ctx = ctx
self._hostname = hostname
self._hostnameBytes = _idnaBytes(hostname)
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
if isIPAddress(hostname) or isIPv6Address(hostname):
self._hostnameBytes = hostname.encode('ascii')
self._sendSNI = False
else:
self._hostnameBytes = _idnaBytes(hostname)
self._sendSNI = True
ctx.set_info_callback(_tolerateErrors(self._identityVerifyingInfoCallback))
def clientConnectionForTLS(self, tlsProtocol):
context = self._ctx
@ -109,7 +116,9 @@ class ClientTLSOptions(object):
return connection
def _identityVerifyingInfoCallback(self, connection, where, ret):
if where & SSL.SSL_CB_HANDSHAKE_START:
# Literal IPv4 and IPv6 addresses are not permitted
# as host names according to the RFCs
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
connection.set_tlsext_host_name(self._hostnameBytes)
@ -119,10 +128,8 @@ class ClientTLSOptionsFactory(object):
def __init__(self, config):
# We don't use config options yet
pass
self._options = CertificateOptions(verify=False)
def get_options(self, host):
return ClientTLSOptions(
host,
CertificateOptions(verify=False).getContext()
)
# Use _makeContext so that we get a fresh OpenSSL CTX each time.
return ClientTLSOptions(host, self._options._makeContext())

View file

@ -23,14 +23,14 @@ from signedjson.sign import sign_json
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
from synapse.events.utils import prune_event
from synapse.events.utils import prune_event, prune_event_dict
logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event, hash_algorithm)
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
# some malformed events lack a 'hashes'. Protect against it being missing
@ -59,35 +59,70 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash
def compute_content_hash(event, hash_algorithm):
event_json = event.get_pdu_json()
event_json.pop("age_ts", None)
event_json.pop("unsigned", None)
event_json.pop("signatures", None)
event_json.pop("hashes", None)
event_json.pop("outlier", None)
event_json.pop("destinations", None)
def compute_content_hash(event_dict, hash_algorithm):
"""Compute the content hash of an event, which is the hash of the
unredacted event.
event_json_bytes = encode_canonical_json(event_json)
Args:
event_dict (dict): The unredacted event as a dict
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw
bytes.
"""
event_dict = dict(event_dict)
event_dict.pop("age_ts", None)
event_dict.pop("unsigned", None)
event_dict.pop("signatures", None)
event_dict.pop("hashes", None)
event_dict.pop("outlier", None)
event_dict.pop("destinations", None)
event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
"""Computes the event reference hash. This is the hash of the redacted
event.
Args:
event (FrozenEvent)
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw
bytes.
"""
tmp_event = prune_event(event)
event_json = tmp_event.get_pdu_json()
event_json.pop("signatures", None)
event_json.pop("age_ts", None)
event_json.pop("unsigned", None)
event_json_bytes = encode_canonical_json(event_json)
event_dict = tmp_event.get_pdu_json()
event_dict.pop("signatures", None)
event_dict.pop("age_ts", None)
event_dict.pop("unsigned", None)
event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
def compute_event_signature(event, signature_name, signing_key):
tmp_event = prune_event(event)
redact_json = tmp_event.get_pdu_json()
def compute_event_signature(event_dict, signature_name, signing_key):
"""Compute the signature of the event for the given name and key.
Args:
event_dict (dict): The event as a dict
signature_name (str): The name of the entity signing the event
(typically the server's hostname).
signing_key (syutil.crypto.SigningKey): The key to sign with
Returns:
dict[str, dict[str, str]]: Returns a dictionary in the same format of
an event's signatures field.
"""
redact_json = prune_event_dict(event_dict)
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
logger.debug("Signing event: %s", encode_canonical_json(redact_json))
@ -96,25 +131,25 @@ def compute_event_signature(event, signature_name, signing_key):
return redact_json["signatures"]
def add_hashes_and_signatures(event, signature_name, signing_key,
def add_hashes_and_signatures(event_dict, signature_name, signing_key,
hash_algorithm=hashlib.sha256):
# if hasattr(event, "old_state_events"):
# state_json_bytes = encode_canonical_json(
# [e.event_id for e in event.old_state_events.values()]
# )
# hashed = hash_algorithm(state_json_bytes)
# event.state_hash = {
# hashed.name: encode_base64(hashed.digest())
# }
"""Add content hash and sign the event
name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm)
Args:
event_dict (dict): The event to add hashes to and sign
signature_name (str): The name of the entity signing the event
(typically the server's hostname).
signing_key (syutil.crypto.SigningKey): The key to sign with
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
"""
if not hasattr(event, "hashes"):
event.hashes = {}
event.hashes[name] = encode_base64(digest)
name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
event.signatures = compute_event_signature(
event,
event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
event_dict["signatures"] = compute_event_signature(
event_dict,
signature_name=signature_name,
signing_key=signing_key,
)

View file

@ -1,147 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import json
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
from twisted.internet.protocol import Factory
from twisted.names.error import DomainError
from twisted.web.http import HTTPClient
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util import logcontext
logger = logging.getLogger(__name__)
KEY_API_V1 = b"/_matrix/key/v1/"
@defer.inlineCallbacks
def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
factory.path = path
factory.host = server_name
endpoint = matrix_federation_endpoint(
reactor, server_name, tls_client_options_factory, timeout=30
)
for i in range(5):
try:
with logcontext.PreserveLoggingContext():
protocol = yield endpoint.connect(factory)
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e:
logger.warn("Error getting key for %r: %s", server_name, e)
if e.status.startswith(b"4"):
# Don't retry for 4xx responses.
raise IOError("Cannot get key for %r" % server_name)
except (ConnectError, DomainError) as e:
logger.warn("Error getting key for %r: %s", server_name, e)
except Exception:
logger.exception("Error getting key for %r", server_name)
raise IOError("Cannot get key for %r" % server_name)
class SynapseKeyClientError(Exception):
"""The key wasn't retrieved from the remote server."""
status = None
pass
class SynapseKeyClientProtocol(HTTPClient):
"""Low level HTTPS client which retrieves an application/json response from
the server and extracts the X.509 certificate for the remote peer from the
SSL connection."""
timeout = 30
def __init__(self):
self.remote_key = defer.Deferred()
self.host = None
self._peer = None
def connectionMade(self):
self._peer = self.transport.getPeer()
logger.debug("Connected to %s", self._peer)
if not isinstance(self.path, bytes):
self.path = self.path.encode('ascii')
if not isinstance(self.host, bytes):
self.host = self.host.encode('ascii')
self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
self.endHeaders()
self.timer = reactor.callLater(
self.timeout,
self.on_timeout
)
def errback(self, error):
if not self.remote_key.called:
self.remote_key.errback(error)
def callback(self, result):
if not self.remote_key.called:
self.remote_key.callback(result)
def handleStatus(self, version, status, message):
if status != b"200":
# logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message)
error = SynapseKeyClientError(
"Non-200 response %r from %r" % (status, self.host)
)
error.status = status
self.errback(error)
self.transport.abortConnection()
def handleResponse(self, response_body_bytes):
try:
json_response = json.loads(response_body_bytes)
except ValueError:
# logger.info("Invalid JSON response from %s",
# self.transport.getHost())
self.transport.abortConnection()
return
certificate = self.transport.getPeerCertificate()
self.callback((json_response, certificate))
self.transport.abortConnection()
self.timer.cancel()
def on_timeout(self):
logger.debug(
"Timeout waiting for response from %s: %s",
self.host, self._peer,
)
self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
class SynapseKeyClientFactory(Factory):
def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
protocol.host = self.host
return protocol

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd.
# Copyright 2017, 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import logging
from collections import namedtuple
@ -32,13 +31,11 @@ from signedjson.sign import (
signature_ids,
verify_signed_json,
)
from unpaddedbase64 import decode_base64, encode_base64
from unpaddedbase64 import decode_base64
from OpenSSL import crypto
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import Codes, RequestSendFailed, SynapseError
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
LoggingContext,
@ -395,32 +392,13 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logger.info(
"Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e),
)
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
keys = {server_name: keys}
defer.returnValue(keys)
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
run_in_background(get_key, server_name, key_ids)
run_in_background(
self.get_server_verify_key_v2_direct,
server_name,
key_ids,
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
@ -524,34 +502,16 @@ class Keyring(object):
if requested_key_id in keys:
continue
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_client_options_factory,
path=("/_matrix/key/v2/server/%s" % (
urllib.parse.quote(requested_key_id),
)).encode("ascii"),
response = yield self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
ignore_backoff=True,
)
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
response_sha256_fingerprints = set()
for fingerprint in response[u"tls_fingerprints"]:
if u"sha256" in fingerprint:
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
@ -657,78 +617,6 @@ class Keyring(object):
defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Args:
server_name (str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
# Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_client_options_factory
)
# Check the response.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
if ("signatures" not in response
or server_name not in response["signatures"]):
raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response:
raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore.
time_now_ms = self.clock.time_msec()
verify_keys = {}
for key_id, key_base64 in response["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response,
server_name,
verify_keys[key_id]
)
yield self.store.store_server_certificate(
server_name,
server_name,
time_now_ms,
tls_certificate,
)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=verify_keys,
)
defer.returnValue(verify_keys)
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
@ -768,7 +656,7 @@ def _handle_key_deferred(verify_request):
try:
with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
except (IOError, RequestSendFailed) as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e),

View file

@ -20,17 +20,25 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership
from synapse.api.constants import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
EventTypes,
JoinRules,
Membership,
RoomVersions,
)
from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True, do_size_check=True):
def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed.
Args:
room_version (str): the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
@ -48,7 +56,6 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
@ -65,9 +72,13 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if event.format_version in (EventFormatVersions.V1,):
# Only older room versions have event IDs to check.
event_id_domain = get_domain_from_id(event.event_id)
# Check the origin domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
@ -167,7 +178,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
check_redaction(event, auth_events)
check_redaction(room_version, event, auth_events)
logger.debug("Allowing! %s", event)
@ -200,11 +211,11 @@ def _is_membership_change_allowed(event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create and event.prev_event_ids()[0] == create.event_id:
if create.content["creator"] == event.state_key:
return
@ -421,7 +432,7 @@ def _can_send_event(event, auth_events):
return True
def check_redaction(event, auth_events):
def check_redaction(room_version, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
@ -441,10 +452,16 @@ def check_redaction(event, auth_events):
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
if room_version in (RoomVersions.V1, RoomVersions.V2,):
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
elif room_version == RoomVersions.V3:
event.internal_metadata.recheck_redaction = True
return True
else:
raise RuntimeError("Unrecognized room version %r" % (room_version,))
raise AuthError(
403,

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +19,9 @@ from distutils.util import strtobool
import six
from unpaddedbase64 import encode_base64
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@ -41,8 +45,13 @@ class _EventInternalMetadata(object):
def is_outlier(self):
return getattr(self, "outlier", False)
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def is_out_of_band_membership(self):
"""Whether this is an out of band membership, like an invite or an invite
rejection. This is needed as those events are marked as outliers, but
they still need to be processed as if they're new events (e.g. updating
invite state in the database, relaying to clients, etc).
"""
return getattr(self, "out_of_band_membership", False)
def get_send_on_behalf_of(self):
"""Whether this server should send the event on behalf of another server.
@ -53,6 +62,21 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "send_on_behalf_of", None)
def need_to_check_redaction(self):
"""Whether the redaction event needs to be rechecked when fetching
from the database.
Starting in room v3 redaction events are accepted up front, and later
checked to see if the redacter and redactee's domains match.
If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false.
Returns:
bool
"""
return getattr(self, "recheck_redaction", False)
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
@ -159,8 +183,28 @@ class EventBase(object):
def keys(self):
return six.iterkeys(self._event_dict)
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
"""
return [e for e, _ in self.prev_events]
def auth_event_ids(self):
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
"""
return [e for e, _ in self.auth_events]
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
event_dict = dict(event_dict)
@ -195,16 +239,6 @@ class FrozenEvent(EventBase):
rejected_reason=rejected_reason,
)
@staticmethod
def from_event(event):
e = FrozenEvent(
event.get_pdu_json()
)
e.internal_metadata = event.internal_metadata
return e
def __str__(self):
return self.__repr__()
@ -214,3 +248,127 @@ class FrozenEvent(EventBase):
self.get("type", None),
self.get("state_key", None),
)
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
# copy.deepcopy
signatures = {
name: {sig_id: sig for sig_id, sig in sigs.items()}
for name, sigs in event_dict.pop("signatures", {}).items()
}
assert "event_id" not in event_dict
unsigned = dict(event_dict.pop("unsigned", {}))
# We intern these strings because they turn up a lot (especially when
# caching).
event_dict = intern_dict(event_dict)
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
self._event_id = None
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEventV2, self).__init__(
frozen_dict,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
)
@property
def event_id(self):
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
if self._event_id:
return self._event_id
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
"""
return self.prev_events
def auth_event_ids(self):
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
"""
return self.auth_events
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
def room_version_to_event_format(room_version):
"""Converts a room version string to the event format
Args:
room_version (str)
Returns:
int
"""
if room_version not in KNOWN_ROOM_VERSIONS:
# We should have already checked version, so this should not happen
raise RuntimeError("Unrecognized room version %s" % (room_version,))
if room_version in (
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
):
return EventFormatVersions.V1
elif room_version in (RoomVersions.V3,):
return EventFormatVersions.V2
else:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
def event_type_from_format_version(format_version):
"""Returns the python type to use to construct an Event object for the
given event format version.
Args:
format_version (int): The event format version
Returns:
type: A type that can be initialized as per the initializer of
`FrozenEvent`
"""
if format_version == EventFormatVersions.V1:
return FrozenEvent
elif format_version == EventFormatVersions.V2:
return FrozenEventV2
else:
raise Exception(
"No event format %r" % (format_version,)
)

View file

@ -13,63 +13,270 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import attr
from twisted.internet import defer
from synapse.api.constants import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
MAX_DEPTH,
EventFormatVersions,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.types import EventID
from synapse.util.stringutils import random_string
from . import EventBase, FrozenEvent, _event_dict_property
from . import (
_EventInternalMetadata,
event_type_from_format_version,
room_version_to_event_format,
)
class EventBuilder(EventBase):
def __init__(self, key_values={}, internal_metadata_dict={}):
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
@attr.s(slots=True, cmp=False, frozen=True)
class EventBuilder(object):
"""A format independent event builder used to build up the event content
before signing the event.
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
(Note that while objects of this class are frozen, the
content/unsigned/internal_metadata fields are still mutable)
Attributes:
format_version (int): Event format version
room_id (str)
type (str)
sender (str)
content (dict)
unsigned (dict)
internal_metadata (_EventInternalMetadata)
_state (StateHandler)
_auth (synapse.api.Auth)
_store (DataStore)
_clock (Clock)
_hostname (str): The hostname of the server creating the event
_signing_key: The signing key to use to sign the event as the server
"""
_state = attr.ib()
_auth = attr.ib()
_store = attr.ib()
_clock = attr.ib()
_hostname = attr.ib()
_signing_key = attr.ib()
format_version = attr.ib()
room_id = attr.ib()
type = attr.ib()
sender = attr.ib()
content = attr.ib(default=attr.Factory(dict))
unsigned = attr.ib(default=attr.Factory(dict))
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@property
def state_key(self):
if self._state_key is not None:
return self._state_key
raise AttributeError("state_key")
def is_state(self):
return self._state_key is not None
@defer.inlineCallbacks
def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
Returns:
Deferred[FrozenEvent]
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids,
)
auth_ids = yield self._auth.compute_auth_events(
self, state_ids,
)
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
if self.format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_ids
prev_events = prev_event_ids
def build(self):
return FrozenEvent.from_event(self)
old_depth = yield self._store.get_max_depth_of(
prev_event_ids,
)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
event_dict = {
"auth_events": auth_events,
"prev_events": prev_events,
"type": self.type,
"room_id": self.room_id,
"sender": self.sender,
"content": self.content,
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
}
if self.is_state():
event_dict["state_key"] = self._state_key
if self._redacts is not None:
event_dict["redacts"] = self._redacts
defer.returnValue(
create_local_event_from_event_dict(
clock=self._clock,
hostname=self._hostname,
signing_key=self._signing_key,
format_version=self.format_version,
event_dict=event_dict,
internal_metadata_dict=self.internal_metadata.get_dict(),
)
)
class EventBuilderFactory(object):
def __init__(self, clock, hostname):
self.clock = clock
self.hostname = hostname
def __init__(self, hs):
self.clock = hs.get_clock()
self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self.event_id_count = 0
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
def new(self, room_version, key_values):
"""Generate an event builder appropriate for the given room version
local_part = str(int(self.clock.time())) + i + random_string(5)
Args:
room_version (str): Version of the room that we're creating an
event builder for
key_values (dict): Fields used as the basis of the new event
e_id = EventID(local_part, self.hostname)
Returns:
EventBuilder
"""
return e_id.to_string()
# There's currently only the one event version defined
if room_version not in KNOWN_ROOM_VERSIONS:
raise Exception(
"No event format defined for version %r" % (room_version,)
)
def new(self, key_values={}):
key_values["event_id"] = self.create_event_id()
return EventBuilder(
store=self.store,
state=self.state,
auth=self.auth,
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
format_version=room_version_to_event_format(room_version),
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
sender=key_values["sender"],
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
)
time_now = int(self.clock.time_msec())
key_values.setdefault("origin", self.hostname)
key_values.setdefault("origin_server_ts", time_now)
def create_local_event_from_event_dict(clock, hostname, signing_key,
format_version, event_dict,
internal_metadata_dict=None):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
key_values.setdefault("unsigned", {})
age = key_values["unsigned"].pop("age", 0)
key_values["unsigned"].setdefault("age_ts", time_now - age)
Args:
clock (Clock)
hostname (str)
signing_key
format_version (int)
event_dict (dict)
internal_metadata_dict (dict|None)
key_values["signatures"] = {}
Returns:
FrozenEvent
"""
return EventBuilder(key_values=key_values,)
# There's currently only the one event version defined
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
)
if internal_metadata_dict is None:
internal_metadata_dict = {}
time_now = int(clock.time_msec())
if format_version == EventFormatVersions.V1:
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
event_dict["origin_server_ts"] = time_now
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)
event_dict["unsigned"].setdefault("age_ts", time_now - age)
event_dict.setdefault("signatures", {})
add_hashes_and_signatures(
event_dict,
hostname,
signing_key,
)
return event_type_from_format_version(format_version)(
event_dict, internal_metadata_dict=internal_metadata_dict,
)
# A counter used when generating new event IDs
_event_id_counter = 0
def _create_event_id(clock, hostname):
"""Create a new event ID
Args:
clock (Clock)
hostname (str): The server name for the event ID
Returns:
str
"""
global _event_id_counter
i = str(_event_id_counter)
_event_id_counter += 1
local_part = str(int(clock.time())) + i + random_string(5)
e_id = EventID(local_part, hostname)
return e_id.to_string()

View file

@ -38,8 +38,31 @@ def prune_event(event):
This is used when we "redact" an event. We want to remove all fields that
the user has specified, but we do want to keep necessary information like
type, state_key etc.
Args:
event (FrozenEvent)
Returns:
FrozenEvent
"""
pruned_event_dict = prune_event_dict(event.get_dict())
from . import event_type_from_format_version
return event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)
def prune_event_dict(event_dict):
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
Args:
event_dict (dict)
Returns:
dict: A copy of the pruned event dict
"""
event_type = event.type
allowed_keys = [
"event_id",
@ -59,13 +82,13 @@ def prune_event(event):
"membership",
]
event_dict = event.get_dict()
event_type = event_dict["type"]
new_content = {}
def add_fields(*fields):
for field in fields:
if field in event.content:
if field in event_dict["content"]:
new_content[field] = event_dict["content"][field]
if event_type == EventTypes.Member:
@ -98,17 +121,17 @@ def prune_event(event):
allowed_fields["content"] = new_content
allowed_fields["unsigned"] = {}
unsigned = {}
allowed_fields["unsigned"] = unsigned
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
if "replaces_state" in event.unsigned:
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
event_unsigned = event_dict.get("unsigned", {})
return type(event)(
allowed_fields,
internal_metadata_dict=event.internal_metadata.get_dict()
)
if "age_ts" in event_unsigned:
unsigned["age_ts"] = event_unsigned["age_ts"]
if "replaces_state" in event_unsigned:
unsigned["replaces_state"] = event_unsigned["replaces_state"]
return allowed_fields
def _copy_field(src, dst, field):
@ -244,6 +267,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
Returns:
dict
"""
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
@ -253,6 +277,8 @@ def serialize_event(e, time_now_ms, as_client_event=True,
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
d["event_id"] = e.event_id
if "age_ts" in d["unsigned"]:
d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]

View file

@ -15,23 +15,29 @@
from six import string_types
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventFormatVersions, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
def validate_new(self, event):
"""Validates the event has roughly the right format
def validate(self, event):
EventID.from_string(event.event_id)
RoomID.from_string(event.room_id)
Args:
event (FrozenEvent)
"""
self.validate_builder(event)
if event.format_version == EventFormatVersions.V1:
EventID.from_string(event.event_id)
required = [
# "auth_events",
"auth_events",
"content",
# "hashes",
"hashes",
"origin",
# "prev_events",
"prev_events",
"sender",
"type",
]
@ -41,8 +47,25 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
strings = [
event_strings = [
"origin",
]
for s in event_strings:
if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "'%s' not a string type" % (s,))
def validate_builder(self, event):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
fields an event would have
Args:
event (EventBuilder|FrozenEvent)
"""
strings = [
"room_id",
"sender",
"type",
]
@ -54,22 +77,7 @@ class EventValidator(object):
if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
# Check that the following keys have dictionary values
# TODO
# Check that the following keys have the correct format for DAGs
# TODO
def validate_new(self, event):
self.validate(event)
RoomID.from_string(event.room_id)
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
@ -86,9 +94,16 @@ class EventValidator(object):
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
elif event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
raise SynapseError(400, "'%s' not a string type" % (s,))

View file

@ -20,10 +20,10 @@ import six
from twisted.internet import defer
from twisted.internet.defer import DeferredList
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events import event_type_from_format_version
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.types import get_domain_from_id
@ -43,8 +43,8 @@ class FederationBase(object):
self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
outlier=False, include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@ -56,13 +56,17 @@ class FederationBase(object):
a new list.
Args:
origin (str)
pdu (list)
outlier (bool)
room_version (str)
outlier (bool): Whether the events are outliers or not
include_none (str): Whether to include None in the returned list
for events that have failed their checks
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(pdus)
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu, deferred):
@ -84,6 +88,7 @@ class FederationBase(object):
res = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
@ -116,16 +121,17 @@ class FederationBase(object):
else:
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
def _check_sigs_and_hash(self, room_version, pdu):
return logcontext.make_deferred_yieldable(
self._check_sigs_and_hashes([pdu])[0],
self._check_sigs_and_hashes(room_version, [pdu])[0],
)
def _check_sigs_and_hashes(self, pdus):
def _check_sigs_and_hashes(self, room_version, pdus):
"""Checks that each of the received events is correctly signed by the
sending server.
Args:
room_version (str): The room version of the PDUs
pdus (list[FrozenEvent]): the events to be checked
Returns:
@ -136,7 +142,7 @@ class FederationBase(object):
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
"""
deferreds = _check_sigs_on_pdus(self.keyring, pdus)
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = logcontext.LoggingContext.current_context()
@ -198,16 +204,17 @@ class FederationBase(object):
class PduToCheckSig(namedtuple("PduToCheckSig", [
"pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
"pdu", "redacted_pdu_json", "sender_domain", "deferreds",
])):
pass
def _check_sigs_on_pdus(keyring, pdus):
def _check_sigs_on_pdus(keyring, room_version, pdus):
"""Check that the given events are correctly signed
Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks
room_version (str): the room version of the PDUs
pdus (Collection[EventBase]): the events to be checked
Returns:
@ -220,9 +227,7 @@ def _check_sigs_on_pdus(keyring, pdus):
# we want to check that the event is signed by:
#
# (a) the server which created the event_id
#
# (b) the sender's server.
# (a) the sender's server
#
# - except in the case of invites created from a 3pid invite, which are exempt
# from this check, because the sender has to match that of the original 3pid
@ -236,34 +241,26 @@ def _check_sigs_on_pdus(keyring, pdus):
# and signatures are *supposed* to be valid whether or not an event has been
# redacted. But this isn't the worst of the ways that 3pid invites are broken.
#
# (b) for V1 and V2 rooms, the server which created the event_id
#
# let's start by getting the domain for each pdu, and flattening the event back
# to JSON.
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
event_id_domain=get_domain_from_id(p.event_id),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
for p in pdus
]
# first make sure that the event is signed by the event_id's domain
deferreds = keyring.verify_json_objects_for_server([
(p.event_id_domain, p.redacted_pdu_json)
for p in pdus_to_check
])
for p, d in zip(pdus_to_check, deferreds):
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks.
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [
p for p in pdus_to_check
if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
if not _is_invite_via_3pid(p.pdu)
]
more_deferreds = keyring.verify_json_objects_for_server([
@ -274,19 +271,43 @@ def _check_sigs_on_pdus(keyring, pdus):
for p, d in zip(pdus_to_check_sender, more_deferreds):
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
if room_version in (
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
):
pdus_to_check_event_id = [
p for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
for p in pdus_to_check_event_id
])
for p, d in zip(pdus_to_check_event_id, more_deferreds):
p.deferreds.append(d)
elif room_version in (RoomVersions.V3,):
pass # No further checks needed, as event IDs are hashes here
else:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
# replace lists of deferreds with single Deferreds
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds):
"""Given a list of one or more deferreds, either return the single deferred, or
combine into a DeferredList.
"""Given a list of deferreds, either return the single deferred,
combine into a DeferredList, or return an already resolved deferred.
"""
if len(deferreds) > 1:
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
else:
assert len(deferreds) == 1
elif len(deferreds) == 1:
return deferreds[0]
else:
return defer.succeed(None)
def _is_invite_via_3pid(event):
@ -297,11 +318,12 @@ def _is_invite_via_3pid(event):
)
def event_from_pdu_json(pdu_json, outlier=False):
def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""Construct a FrozenEvent from an event json received over federation
Args:
pdu_json (object): pdu as received over federation
event_format_version (int): The event format version
outlier (bool): True to mark this event as an outlier
Returns:
@ -313,7 +335,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
assert_params_in_dict(pdu_json, ('type', 'depth'))
depth = pdu_json['depth']
if not isinstance(depth, six.integer_types):
@ -325,8 +347,8 @@ def event_from_pdu_json(pdu_json, outlier=False):
elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
event = FrozenEvent(
pdu_json
event = event_type_from_format_version(event_format_version)(
pdu_json,
)
event.internal_metadata.outlier = outlier

View file

@ -25,14 +25,19 @@ from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
from synapse.api.constants import (
KNOWN_ROOM_VERSIONS,
EventTypes,
Membership,
RoomVersions,
)
from synapse.api.errors import (
CodeMessageException,
FederationDeniedError,
HttpResponseException,
SynapseError,
)
from synapse.events import builder
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
@ -66,6 +71,9 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
clock=self._clock,
@ -162,13 +170,13 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
def backfill(self, dest, room_id, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to backfill.
room_id (str): The room_id to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
@ -183,18 +191,21 @@ class FederationClient(FederationBase):
return
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
dest, room_id, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [
event_from_pdu_json(p, outlier=False)
event_from_pdu_json(p, format_ver, outlier=False)
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
self._check_sigs_and_hashes(room_version, pdus),
consumeErrors=True,
).addErrback(unwrapFirstError))
@ -202,7 +213,8 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
def get_pdu(self, destinations, event_id, room_version, outlier=False,
timeout=None):
"""Requests the PDU with given origin and ID from the remote home
servers.
@ -212,6 +224,7 @@ class FederationClient(FederationBase):
Args:
destinations (list): Which home servers to query
event_id (str): event to fetch
room_version (str): version of the room
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@ -230,6 +243,8 @@ class FederationClient(FederationBase):
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
format_ver = room_version_to_event_format(room_version)
signed_pdu = None
for destination in destinations:
now = self._clock.time_msec()
@ -245,7 +260,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
event_from_pdu_json(p, outlier=outlier)
event_from_pdu_json(p, format_ver, outlier=outlier)
for p in transaction_data["pdus"]
]
@ -253,7 +268,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdu = yield self._check_sigs_and_hash(room_version, pdu)
break
@ -339,12 +354,16 @@ class FederationClient(FederationBase):
destination, room_id, event_id=event_id,
)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [
event_from_pdu_json(p, outlier=True) for p in result["pdus"]
event_from_pdu_json(p, format_ver, outlier=True)
for p in result["pdus"]
]
auth_chain = [
event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, format_ver, outlier=True)
for p in result.get("auth_chain", [])
]
@ -355,7 +374,8 @@ class FederationClient(FederationBase):
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination,
[p for p in pdus if p.event_id not in seen_events],
outlier=True
outlier=True,
room_version=room_version,
)
signed_pdus.extend(
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
@ -364,7 +384,8 @@ class FederationClient(FederationBase):
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination,
[p for p in auth_chain if p.event_id not in seen_events],
outlier=True
outlier=True,
room_version=room_version,
)
signed_auth.extend(
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
@ -411,6 +432,8 @@ class FederationClient(FederationBase):
random.shuffle(srvs)
return srvs
room_version = yield self.store.get_room_version(room_id)
batch_size = 20
missing_events = list(missing_events)
for i in range(0, len(missing_events), batch_size):
@ -421,6 +444,7 @@ class FederationClient(FederationBase):
self.get_pdu,
destinations=random_server_list(),
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]
@ -445,13 +469,17 @@ class FederationClient(FederationBase):
destination, room_id, event_id,
)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, format_ver, outlier=True)
for p in res["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
destination, auth_chain,
outlier=True, room_version=room_version,
)
signed_auth.sort(key=lambda e: e.depth)
@ -522,6 +550,8 @@ class FederationClient(FederationBase):
Does so by asking one of the already participating servers to create an
event with proper context.
Returns a fully signed and hashed event.
Note that this does not append any events to any graphs.
Args:
@ -536,8 +566,10 @@ class FederationClient(FederationBase):
params (dict[str, str|Iterable[str]]): Query parameters to include in the
request.
Return:
Deferred: resolves to a tuple of (origin (str), event (object))
where origin is the remote homeserver which generated the event.
Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
`(origin, event, event_format)` where origin is the remote
homeserver which generated the event, and event_format is one of
`synapse.api.constants.EventFormatVersions`.
Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code.
@ -557,6 +589,11 @@ class FederationClient(FederationBase):
destination, room_id, user_id, membership, params,
)
# Note: If not supplied, the room version may be either v1 or v2,
# however either way the event format version will be v1.
room_version = ret.get("room_version", RoomVersions.V1)
event_format = room_version_to_event_format(room_version)
pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response")
@ -571,17 +608,20 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
ev = builder.create_local_event_from_event_dict(
self._clock, self.hostname, self.signing_key,
format_version=event_format, event_dict=pdu_dict,
)
defer.returnValue(
(destination, ev)
(destination, ev, event_format)
)
return self._try_destination_list(
"make_" + membership, destinations, send_request,
)
def send_join(self, destinations, pdu):
def send_join(self, destinations, pdu, event_format_version):
"""Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
@ -591,6 +631,7 @@ class FederationClient(FederationBase):
destinations (str): Candidate homeservers which are probably
participating in the room.
pdu (BaseEvent): event to be sent
event_format_version (int): The event format version
Return:
Deferred: resolves to a dict with members ``origin`` (a string
@ -636,12 +677,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content)
state = [
event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, event_format_version, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, event_format_version, outlier=True)
for p in content.get("auth_chain", [])
]
@ -650,9 +691,21 @@ class FederationClient(FederationBase):
for p in itertools.chain(state, auth_chain)
}
room_version = None
for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
room_version = e.content.get("room_version", RoomVersions.V1)
break
if room_version is None:
# If the state doesn't have a create event then the room is
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
outlier=True,
room_version=room_version,
)
valid_pdus_map = {
@ -690,32 +743,75 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
try:
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
except HttpResponseException as e:
if e.code == 403:
raise e.to_synapse_error()
raise
room_version = yield self.store.get_room_version(room_id)
content = yield self._do_send_invite(destination, pdu, room_version)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = event_from_pdu_json(pdu_dict)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(pdu_dict, format_ver)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
pdu = yield self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdu)
@defer.inlineCallbacks
def _do_send_invite(self, destination, pdu, room_version):
"""Actually sends the invite, first trying v2 API and falling back to
v1 API if necessary.
Args:
destination (str): Target server
pdu (FrozenEvent)
room_version (str)
Returns:
dict: The event as a dict as returned by the remote server
"""
time_now = self._clock.time_msec()
try:
content = yield self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content={
"event": pdu.get_pdu_json(time_now),
"room_version": room_version,
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
defer.returnValue(content)
except HttpResponseException as e:
if e.code in [400, 404]:
if room_version in (RoomVersions.V1, RoomVersions.V2):
pass # We'll fall through
else:
raise Exception("Remote server is too old")
elif e.code == 403:
raise e.to_synapse_error()
else:
raise
# Didn't work, try v1 API.
# Note the v1 API returns a tuple of `(200, content)`
_, content = yield self.transport_layer.send_invite_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
defer.returnValue(content)
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.
@ -785,13 +881,16 @@ class FederationClient(FederationBase):
content=send_content,
)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(e)
event_from_pdu_json(e, format_ver)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
destination, auth_chain, outlier=True, room_version=room_version,
)
signed_auth.sort(key=lambda e: e.depth)
@ -833,13 +932,16 @@ class FederationClient(FederationBase):
timeout=timeout,
)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
events = [
event_from_pdu_json(e)
event_from_pdu_json(e, format_ver)
for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False
destination, events, outlier=False, room_version=room_version,
)
except HttpResponseException as e:
if not e.code == 400:

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
FederationError,
@ -34,6 +34,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
@ -147,6 +148,22 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
# Reject if PDU count > 50 and EDU count > 100
if (len(transaction.pdus) > 50
or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
logger.info(
"Transaction PDU or EDU count too large. Returning 400",
)
response = {}
yield self.transaction_actions.set_response(
origin,
transaction,
400, response
)
defer.returnValue((400, response))
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin)
@ -162,8 +179,29 @@ class FederationServer(FederationBase):
p["age_ts"] = request_time - int(p["age"])
del p["age"]
event = event_from_pdu_json(p)
room_id = event.room_id
# We try and pull out an event ID so that if later checks fail we
# can log something sensible. We don't mandate an event ID here in
# case future event formats get rid of the key.
possible_event_id = p.get("event_id", "<Unknown>")
# Now we get the room ID so that we can check that we know the
# version of the room.
room_id = p.get("room_id")
if not room_id:
logger.info(
"Ignoring PDU as does not have a room_id. Event ID: %s",
possible_event_id,
)
continue
try:
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue
event = event_from_pdu_json(p, format_ver)
pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {}
@ -300,7 +338,7 @@ class FederationServer(FederationBase):
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0]
)
@ -323,11 +361,6 @@ class FederationServer(FederationBase):
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc()
@ -352,18 +385,23 @@ class FederationServer(FederationBase):
})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = event_from_pdu_json(content)
def on_invite_request(self, origin, content, room_version):
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
def on_send_join_request(self, origin, content, room_id):
logger.debug("on_send_join_request: content: %s", content)
pdu = event_from_pdu_json(content)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
@ -383,13 +421,22 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
room_version = yield self.store.get_room_version(room_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
defer.returnValue({
"event": pdu.get_pdu_json(time_now),
"room_version": room_version,
})
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
def on_send_leave_request(self, origin, content, room_id):
logger.debug("on_send_leave_request: content: %s", content)
pdu = event_from_pdu_json(content)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
@ -435,13 +482,16 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(e)
event_from_pdu_json(e, format_ver)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True
origin, auth_chain, outlier=True, room_version=room_version,
)
ret = yield self.handler.on_query_auth(
@ -586,16 +636,19 @@ class FederationServer(FederationBase):
"""
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if origin != get_domain_from_id(pdu.event_id):
if origin != get_domain_from_id(pdu.sender):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893).
# origin. See bug #1893. This is also true for some third party
# invites).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join'
pdu.content.get("membership", None) in (
Membership.JOIN, Membership.INVITE,
)
):
logger.info(
"Discarding PDU %s from invalid origin %s",
@ -608,9 +661,12 @@ class FederationServer(FederationBase):
pdu.event_id, origin
)
# We've already checked that we know the room version by this point
room_version = yield self.store.get_room_version(pdu.room_id)
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
pdu = yield self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
raise FederationError(
"ERROR",

View file

@ -22,14 +22,17 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
from synapse.api.errors import FederationDeniedError, HttpResponseException
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
)
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
from synapse.metrics import (
LaterGauge,
event_processing_loop_counter,
event_processing_loop_room_count,
events_processed_counter,
sent_edus_counter,
sent_transactions_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
@ -43,10 +46,24 @@ from .units import Edu, Transaction
logger = logging.getLogger(__name__)
sent_pdus_destination_dist_count = Counter(
"synapse_federation_client_sent_pdu_destinations:count", ""
"synapse_federation_client_sent_pdu_destinations:count",
"Number of PDUs queued for sending to one or more destinations",
)
sent_pdus_destination_dist_total = Counter(
"synapse_federation_client_sent_pdu_destinations:total", ""
"Total number of PDUs queued for sending across all destinations",
)
sent_edus_counter = Counter(
"synapse_federation_client_sent_edus",
"Total number of EDUs successfully sent",
)
sent_edus_by_type = Counter(
"synapse_federation_client_sent_edus_by_type",
"Number of sent EDUs successfully sent, by event type",
["type"],
)
@ -171,7 +188,7 @@ class TransactionQueue(object):
def handle_event(event):
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.event_id)
is_mine = self.is_mine_id(event.sender)
if not is_mine and send_on_behalf_of is None:
return
@ -183,9 +200,7 @@ class TransactionQueue(object):
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = yield self.state.get_current_hosts_in_room(
event.room_id, latest_event_ids=[
prev_id for prev_id, _ in event.prev_events
],
event.room_id, latest_event_ids=event.prev_event_ids(),
)
except Exception:
logger.exception(
@ -358,8 +373,6 @@ class TransactionQueue(object):
logger.info("Not sending EDU to ourselves")
return
sent_edus_counter.inc()
if key:
self.pending_edus_keyed_by_dest.setdefault(
destination, {}
@ -494,6 +507,9 @@ class TransactionQueue(object):
)
if success:
sent_transactions_counter.inc()
sent_edus_counter.inc(len(pending_edus))
for edu in pending_edus:
sent_edus_by_type.labels(edu.edu_type).inc()
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if device_message_edus:
@ -520,11 +536,21 @@ class TransactionQueue(object):
)
except FederationDeniedError as e:
logger.info(e)
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
except HttpResponseException as e:
logger.warning(
"TX [%s] Received %d response to transaction: %s",
destination, e.code, e,
)
except RequestSendFailed as e:
logger.warning("TX [%s] Failed to send transaction: %s", destination, e)
for p, _ in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id,
destination)
except Exception:
logger.exception(
"TX [%s] Failed to send transaction",
destination,
e,
)
for p, _ in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id,

View file

@ -21,7 +21,7 @@ from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.util.logutils import log_function
logger = logging.getLogger(__name__)
@ -51,7 +51,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
path = _create_path(PREFIX, "/state/%s/", room_id)
path = _create_v1_path("/state/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@ -73,7 +73,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
path = _create_path(PREFIX, "/state_ids/%s/", room_id)
path = _create_v1_path("/state_ids/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@ -95,7 +95,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
path = _create_path(PREFIX, "/event/%s/", event_id)
path = _create_v1_path("/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout)
@log_function
@ -121,7 +121,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
path = _create_path(PREFIX, "/backfill/%s/", room_id)
path = _create_v1_path("/backfill/%s/", room_id)
args = {
"v": event_tuples,
@ -167,7 +167,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
path = _create_v1_path("/send/%s/", transaction.transaction_id)
response = yield self.client.put_json(
transaction.destination,
@ -184,7 +184,7 @@ class TransportLayerClient(object):
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
path = _create_path(PREFIX, "/query/%s", query_type)
path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json(
destination=destination,
@ -231,7 +231,7 @@ class TransportLayerClient(object):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@ -258,7 +258,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@ -271,7 +271,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@ -289,8 +289,22 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_invite_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@ -306,7 +320,7 @@ class TransportLayerClient(object):
def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None, include_all_networks=False,
third_party_instance_id=None):
path = PREFIX + "/publicRooms"
path = _create_v1_path("/publicRooms")
args = {
"include_all_networks": "true" if include_all_networks else "false",
@ -332,7 +346,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json(
destination=destination,
@ -345,7 +359,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
@ -357,7 +371,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
@ -392,7 +406,7 @@ class TransportLayerClient(object):
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/keys/query"
path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json(
destination=destination,
@ -419,7 +433,7 @@ class TransportLayerClient(object):
Returns:
A dict containg the device keys.
"""
path = _create_path(PREFIX, "/user/devices/%s", user_id)
path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
@ -455,7 +469,7 @@ class TransportLayerClient(object):
A dict containg the one-time keys.
"""
path = PREFIX + "/user/keys/claim"
path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json(
destination=destination,
@ -469,7 +483,7 @@ class TransportLayerClient(object):
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
path = _create_v1_path("/get_missing_events/%s", room_id,)
content = yield self.client.post_json(
destination=destination,
@ -489,7 +503,7 @@ class TransportLayerClient(object):
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
path = _create_v1_path("/groups/%s/profile", group_id,)
return self.client.get_json(
destination=destination,
@ -508,7 +522,7 @@ class TransportLayerClient(object):
requester_user_id (str)
content (dict): The new profile of the group
"""
path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
path = _create_v1_path("/groups/%s/profile", group_id,)
return self.client.post_json(
destination=destination,
@ -522,7 +536,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
path = _create_v1_path("/groups/%s/summary", group_id,)
return self.client.get_json(
destination=destination,
@ -535,7 +549,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
path = _create_v1_path("/groups/%s/rooms", group_id,)
return self.client.get_json(
destination=destination,
@ -548,7 +562,7 @@ class TransportLayerClient(object):
content):
"""Add a room to a group
"""
path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@ -562,8 +576,8 @@ class TransportLayerClient(object):
config_key, content):
"""Update room in group
"""
path = _create_path(
PREFIX, "/groups/%s/room/%s/config/%s",
path = _create_v1_path(
"/groups/%s/room/%s/config/%s",
group_id, room_id, config_key,
)
@ -578,7 +592,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@ -591,7 +605,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
path = _create_path(PREFIX, "/groups/%s/users", group_id,)
path = _create_v1_path("/groups/%s/users", group_id,)
return self.client.get_json(
destination=destination,
@ -604,7 +618,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
path = _create_v1_path("/groups/%s/invited_users", group_id,)
return self.client.get_json(
destination=destination,
@ -617,8 +631,8 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
path = _create_path(
PREFIX, "/groups/%s/users/%s/accept_invite",
path = _create_v1_path(
"/groups/%s/users/%s/accept_invite",
group_id, user_id,
)
@ -633,7 +647,7 @@ class TransportLayerClient(object):
def join_group(self, destination, group_id, user_id, content):
"""Attempts to join a group
"""
path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
destination=destination,
@ -646,7 +660,7 @@ class TransportLayerClient(object):
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
"""Invite a user to a group
"""
path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@ -662,7 +676,7 @@ class TransportLayerClient(object):
invited.
"""
path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@ -676,7 +690,7 @@ class TransportLayerClient(object):
user_id, content):
"""Remove a user fron a group
"""
path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@ -693,7 +707,7 @@ class TransportLayerClient(object):
kicked from the group.
"""
path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@ -708,7 +722,7 @@ class TransportLayerClient(object):
the attestations
"""
path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
@ -723,12 +737,12 @@ class TransportLayerClient(object):
"""Update a room entry in a group summary
"""
if category_id:
path = _create_path(
PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@ -744,12 +758,12 @@ class TransportLayerClient(object):
"""Delete a room entry in a group summary
"""
if category_id:
path = _create_path(
PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@ -762,7 +776,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
path = _create_v1_path("/groups/%s/categories", group_id,)
return self.client.get_json(
destination=destination,
@ -775,7 +789,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.get_json(
destination=destination,
@ -789,7 +803,7 @@ class TransportLayerClient(object):
content):
"""Update a category in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.post_json(
destination=destination,
@ -804,7 +818,7 @@ class TransportLayerClient(object):
category_id):
"""Delete a category in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.delete_json(
destination=destination,
@ -817,7 +831,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
path = _create_v1_path("/groups/%s/roles", group_id,)
return self.client.get_json(
destination=destination,
@ -830,7 +844,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.get_json(
destination=destination,
@ -844,7 +858,7 @@ class TransportLayerClient(object):
content):
"""Update a role in a group
"""
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.post_json(
destination=destination,
@ -858,7 +872,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.delete_json(
destination=destination,
@ -873,12 +887,12 @@ class TransportLayerClient(object):
"""Update a users entry in a group
"""
if role_id:
path = _create_path(
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.post_json(
destination=destination,
@ -893,7 +907,7 @@ class TransportLayerClient(object):
content):
"""Sets the join policy for a group
"""
path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
return self.client.put_json(
destination=destination,
@ -909,12 +923,12 @@ class TransportLayerClient(object):
"""Delete a users entry in a group
"""
if role_id:
path = _create_path(
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.delete_json(
destination=destination,
@ -927,7 +941,7 @@ class TransportLayerClient(object):
"""Get the groups a list of users are publicising
"""
path = PREFIX + "/get_groups_publicised"
path = _create_v1_path("/get_groups_publicised")
content = {"user_ids": user_ids}
@ -939,20 +953,43 @@ class TransportLayerClient(object):
)
def _create_path(prefix, path, *args):
"""Creates a path from the prefix, path template and args. Ensures that
all args are url encoded.
def _create_v1_path(path, *args):
"""Creates a path against V1 federation API from the path template and
args. Ensures that all args are url encoded.
Example:
_create_path(PREFIX, "/event/%s/", event_id)
_create_v1_path("/event/%s/", event_id)
Args:
prefix (str)
path (str): String template for the path
args: ([str]): Args to insert into path. Each arg will be url encoded
Returns:
str
"""
return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
return (
FEDERATION_V1_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
)
def _create_v2_path(path, *args):
"""Creates a path against V2 federation API from the path template and
args. Ensures that all args are url encoded.
Example:
_create_v2_path("/event/%s/", event_id)
Args:
path (str): String template for the path
args: ([str]): Args to insert into path. Each arg will be url encoded
Returns:
str
"""
return (
FEDERATION_V2_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
)

View file

@ -21,8 +21,9 @@ import re
from twisted.internet import defer
import synapse
from synapse.api.constants import RoomVersions
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
@ -42,9 +43,20 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
def __init__(self, hs):
def __init__(self, hs, servlet_groups=None):
"""Initialize the TransportLayerServer
Will by default register all servlets. For custom behaviour, pass in
a list of servlet_groups to register.
Args:
hs (synapse.server.HomeServer): homeserver
servlet_groups (list[str], optional): List of servlet groups to register.
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
self.hs = hs
self.clock = hs.get_clock()
self.servlet_groups = servlet_groups
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
@ -66,6 +78,7 @@ class TransportLayerServer(JsonResource):
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
servlet_groups=self.servlet_groups,
)
@ -227,6 +240,8 @@ class BaseFederationServlet(object):
"""
REQUIRE_AUTH = True
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
@ -286,7 +301,7 @@ class BaseFederationServlet(object):
return new_func
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
for method in ("GET", "PUT", "POST"):
code = getattr(self, "on_%s" % (method), None)
@ -362,14 +377,6 @@ class FederationSendServlet(BaseFederationServlet):
defer.returnValue((code, response))
class FederationPullServlet(BaseFederationServlet):
PATH = "/pull/"
# This is for when someone asks us for everything since version X
def on_GET(self, origin, content, query):
return self.handler.on_pull_request(query["origin"][0], query["v"])
class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/"
@ -474,7 +481,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, event_id):
content = yield self.handler.on_send_leave_request(origin, content)
content = yield self.handler.on_send_leave_request(origin, content, room_id)
defer.returnValue((200, content))
@ -492,18 +499,50 @@ class FederationSendJoinServlet(BaseFederationServlet):
def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = yield self.handler.on_send_join_request(origin, content)
content = yield self.handler.on_send_join_request(origin, content, context)
defer.returnValue((200, content))
class FederationInviteServlet(BaseFederationServlet):
class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id):
# We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
# invites
content = yield self.handler.on_invite_request(
origin, content, room_version=RoomVersions.V1,
)
# V1 federation API is defined to return a content of `[200, {...}]`
# due to a historical bug.
defer.returnValue((200, (200, content)))
class FederationV2InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = yield self.handler.on_invite_request(origin, content)
room_version = content["room_version"]
event = content["event"]
invite_room_state = content["invite_room_state"]
# Synapse expects invite_room_state to be in unsigned, as it is in v1
# API
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
content = yield self.handler.on_invite_request(
origin, event, room_version=room_version,
)
defer.returnValue((200, content))
@ -1262,7 +1301,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
FederationStateIdsServlet,
@ -1273,7 +1311,8 @@ FEDERATION_SERVLET_CLASSES = (
FederationEventServlet,
FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationInviteServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
@ -1282,10 +1321,12 @@ FEDERATION_SERVLET_CLASSES = (
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
OpenIdUserInfo,
FederationVersionServlet,
)
OPENID_SERVLET_CLASSES = (
OpenIdUserInfo,
)
ROOM_LIST_CLASSES = (
PublicRoomList,
@ -1324,44 +1365,83 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
)
DEFAULT_SERVLET_GROUPS = (
"federation",
"room_list",
"group_server",
"group_local",
"group_attestation",
"openid",
)
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_federation_server(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
for servletclass in ROOM_LIST_CLASSES:
servletclass(
handler=hs.get_room_list_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None):
"""Initialize and register servlet classes.
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_server_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
Will by default register all servlets. For custom behaviour, pass in
a list of servlet_groups to register.
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_local_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
Args:
hs (synapse.server.HomeServer): homeserver
resource (TransportLayerServer): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
if not servlet_groups:
servlet_groups = DEFAULT_SERVLET_GROUPS
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_attestation_renewer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "federation" in servlet_groups:
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_federation_server(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES:
servletclass(
handler=hs.get_federation_server(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "room_list" in servlet_groups:
for servletclass in ROOM_LIST_CLASSES:
servletclass(
handler=hs.get_room_list_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "group_server" in servlet_groups:
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_server_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "group_local" in servlet_groups:
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_local_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "group_attestation" in servlet_groups:
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_attestation_renewer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)

View file

@ -117,9 +117,6 @@ class Transaction(JsonEncodedObject):
"Require 'transaction_id' to construct a Transaction"
)
for p in pdus:
p.transaction_id = kwargs["transaction_id"]
kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
return Transaction(**kwargs)

View file

@ -42,7 +42,7 @@ from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.errors import RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.logcontext import run_in_background
@ -191,6 +191,11 @@ class GroupAttestionRenewer(object):
yield self.store.update_attestation_renewal(
group_id, user_id, attestation
)
except RequestSendFailed as e:
logger.warning(
"Failed to renew attestation of %r in %r: %s",
user_id, group_id, e,
)
except Exception:
logger.exception("Error renewing attestation of %r in %r",
user_id, group_id)

View file

@ -17,7 +17,6 @@ from .admin import AdminHandler
from .directory import DirectoryHandler
from .federation import FederationHandler
from .identity import IdentityHandler
from .register import RegistrationHandler
from .search import SearchHandler
@ -41,7 +40,6 @@ class Handlers(object):
"""
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
self.federation_handler = FederationHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)

View file

@ -167,4 +167,4 @@ class BaseHandler(object):
ratelimit=False,
)
except Exception as e:
logger.warn("Error kicking guest user: %s" % (e,))
logger.exception("Error kicking guest user: %s" % (e,))

151
synapse/handlers/acme.py Normal file
View file

@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import attr
from zope.interface import implementer
import twisted
import twisted.internet.error
from twisted.internet import defer
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource
from synapse.app import check_bind_error
logger = logging.getLogger(__name__)
try:
from txacme.interfaces import ICertificateStore
@attr.s
@implementer(ICertificateStore)
class ErsatzStore(object):
"""
A store that only stores in memory.
"""
certs = attr.ib(default=attr.Factory(dict))
def store(self, server_name, pem_objects):
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
except ImportError:
# txacme is missing
pass
class AcmeHandler(object):
def __init__(self, hs):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
@defer.inlineCallbacks
def start_listening(self):
# Configure logging for txacme, if you need to debug
# from eliot import add_destinations
# from eliot.twisted import TwistedDestination
#
# add_destinations(TwistedDestination())
from txacme.challenges import HTTP01Responder
from txacme.service import AcmeIssuingService
from txacme.endpoint import load_or_create_client_key
from txacme.client import Client
from josepy.jwa import RS256
self._store = ErsatzStore()
responder = HTTP01Responder()
self._issuer = AcmeIssuingService(
cert_store=self._store,
client_creator=(
lambda: Client.from_url(
reactor=self.reactor,
url=URL.from_text(self.hs.config.acme_url),
key=load_or_create_client_key(
FilePath(self.hs.config.config_dir_path)
),
alg=RS256,
)
),
clock=self.reactor,
responders=[responder],
)
well_known = Resource()
well_known.putChild(b'acme-challenge', responder.resource)
responder_resource = Resource()
responder_resource.putChild(b'.well-known', well_known)
responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
srv = server.Site(responder_resource)
bind_addresses = self.hs.config.acme_bind_addresses
for host in bind_addresses:
logger.info(
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port,
)
try:
self.reactor.listenTCP(
self.hs.config.acme_port,
srv,
interface=host,
)
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
# Make sure we are registered to the ACME server. There's no public API
# for this, it is usually triggered by startService, but since we don't
# want it to control where we save the certificates, we have to reach in
# and trigger the registration machinery ourselves.
self._issuer._registered = False
yield self._issuer._ensure_registered()
@defer.inlineCallbacks
def provision_certificate(self):
logger.warning("Reprovisioning %s", self._acme_domain)
try:
yield self._issuer.issue_cert(self._acme_domain)
except Exception:
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self._acme_domain)
cert_chain = self._store.certs[self._acme_domain]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
for x in cert_chain:
if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
private_key_file.write(x)
with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
for x in cert_chain:
if x.startswith(b"-----BEGIN CERTIFICATE-----"):
certificate_file.write(x)
except Exception:
logger.exception("Failed saving!")
raise
defer.returnValue(True)

View file

@ -59,6 +59,7 @@ class AuthHandler(BaseHandler):
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
LoginType.DUMMY: self._check_dummy_auth,
LoginType.TERMS: self._check_terms_auth,
}
self.bcrypt_rounds = hs.config.bcrypt_rounds
@ -431,6 +432,9 @@ class AuthHandler(BaseHandler):
def _check_dummy_auth(self, authdict, _):
return defer.succeed(True)
def _check_terms_auth(self, authdict, _):
return defer.succeed(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
if 'threepid_creds' not in authdict:
@ -462,6 +466,22 @@ class AuthHandler(BaseHandler):
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
def _get_params_terms(self):
return {
"policies": {
"privacy_policy": {
"version": self.hs.config.user_consent_version,
"en": {
"name": self.hs.config.user_consent_policy_name,
"url": "%s_matrix/consent?v=%s" % (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
},
},
},
}
def _auth_dict_for_flows(self, flows, session):
public_flows = []
for f in flows:
@ -469,6 +489,7 @@ class AuthHandler(BaseHandler):
get_params = {
LoginType.RECAPTCHA: self._get_params_recaptcha,
LoginType.TERMS: self._get_params_terms,
}
params = {}
@ -542,10 +563,10 @@ class AuthHandler(BaseHandler):
insensitively, but return None if there are multiple inexact matches.
Args:
(str) user_id: complete @user:id
(unicode|bytes) user_id: complete @user:id
Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches
"""
res = yield self._find_user_id_and_pwd_hash(user_id)
@ -933,6 +954,15 @@ class MacaroonGenerator(object):
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
"""
Args:
user_id (unicode):
duration_in_ms (int):
Returns:
unicode
"""
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()

View file

@ -20,7 +20,11 @@ from twisted.internet import defer
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.api.errors import FederationDeniedError
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
)
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
@ -504,13 +508,13 @@ class DeviceListEduUpdater(object):
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
except (
NotRetryingDestination, RequestSendFailed, HttpResponseException,
):
# TODO: Remember that we are now out of sync and try again
# later
logger.warn(
"Failed to handle device list update for %s,"
" we're not retrying the remote",
user_id,
"Failed to handle device list update for %s", user_id,
)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
@ -532,6 +536,25 @@ class DeviceListEduUpdater(object):
stream_id = result["stream_id"]
devices = result["devices"]
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
# send a message). Maintaining lots of devices per user in the
# cache can cause serious performance issues as if this request
# takes more than 60s to complete, internal replication from the
# inbound federation worker to the synapse master may time out
# causing the inbound federation to fail and causing the remote
# server to retry, causing a DoS. So in this scenario we give
# up on storing the total list of devices and only handle the
# delta instead.
if len(devices) > 1000:
logger.warn(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id, len(devices)
)
devices = []
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)

View file

@ -57,8 +57,8 @@ class DirectoryHandler(BaseHandler):
# general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
@ -112,7 +112,9 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to create this alias",
)
if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()):
if not self.config.is_alias_creation_allowed(
user_id, room_id, room_alias.to_string(),
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
@ -138,9 +140,30 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
def delete_association(self, requester, room_alias):
# association deletion for human users
def delete_association(self, requester, room_alias, send_event=True):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
delete_appservice_association)
Args:
requester (Requester):
room_alias (RoomAlias):
send_event (bool): Whether to send an updated m.room.aliases event.
Note that, if we delete the canonical alias, we will always attempt
to send an m.room.canonical_alias event
Returns:
Deferred[unicode]: room id that the alias used to point to
Raises:
NotFoundError: if the alias doesn't exist
AuthError: if the user doesn't have perms to delete the alias (ie, the user
is neither the creator of the alias, nor a server admin.
SynapseError: if the alias belongs to an AS
"""
user_id = requester.user.to_string()
try:
@ -168,10 +191,11 @@ class DirectoryHandler(BaseHandler):
room_id = yield self._delete_association(room_alias)
try:
yield self.send_room_alias_update_event(
requester,
room_id
)
if send_event:
yield self.send_room_alias_update_event(
requester,
room_id
)
yield self._update_canonical_alias(
requester,
@ -373,9 +397,9 @@ class DirectoryHandler(BaseHandler):
room_id (str)
visibility (str): "public" or "private"
"""
if not self.spam_checker.user_may_publish_room(
requester.user.to_string(), room_id
):
user_id = requester.user.to_string()
if not self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403,
"This user is not permitted to publish rooms to the room list"
@ -393,7 +417,24 @@ class DirectoryHandler(BaseHandler):
yield self.auth.check_can_change_room_list(room_id, requester.user)
yield self.store.set_room_is_public(room_id, visibility == "public")
making_public = visibility == "public"
if making_public:
room_aliases = yield self.store.get_aliases_for_room(room_id)
canonical_alias = yield self.store.get_canonical_alias_for_room(room_id)
if canonical_alias:
room_aliases.append(canonical_alias)
if not self.config.is_publishing_room_allowed(
user_id, room_id, room_aliases,
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(
403, "Not allowed to publish room",
)
yield self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
def edit_published_appservice_room_list(self, appservice_id, network_id,

View file

@ -19,7 +19,13 @@ from six import iteritems
from twisted.internet import defer
from synapse.api.errors import RoomKeysVersionError, StoreError, SynapseError
from synapse.api.errors import (
Codes,
NotFoundError,
RoomKeysVersionError,
StoreError,
SynapseError,
)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@ -55,6 +61,8 @@ class E2eRoomKeysHandler(object):
room_id(string): room ID to get keys for, for None to get keys for all rooms
session_id(string): session ID to get keys for, for None to get keys for all
sessions
Raises:
NotFoundError: if the backup version does not exist
Returns:
A deferred list of dicts giving the session_data and message metadata for
these room keys.
@ -63,13 +71,19 @@ class E2eRoomKeysHandler(object):
# we deliberately take the lock to get keys so that changing the version
# works atomically
with (yield self._upload_linearizer.queue(user_id)):
# make sure the backup version exists
try:
yield self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
results = yield self.store.get_e2e_room_keys(
user_id, version, room_id, session_id
)
if results['rooms'] == {}:
raise SynapseError(404, "No room_keys found")
defer.returnValue(results)
@defer.inlineCallbacks
@ -120,7 +134,7 @@ class E2eRoomKeysHandler(object):
}
Raises:
SynapseError: with code 404 if there are no versions defined
NotFoundError: if there are no versions defined
RoomKeysVersionError: if the uploaded version is not the current version
"""
@ -134,7 +148,7 @@ class E2eRoomKeysHandler(object):
version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Version '%s' not found" % (version,))
raise NotFoundError("Version '%s' not found" % (version,))
else:
raise
@ -148,7 +162,7 @@ class E2eRoomKeysHandler(object):
raise RoomKeysVersionError(current_version=version_info['version'])
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Version '%s' not found" % (version,))
raise NotFoundError("Version '%s' not found" % (version,))
else:
raise
@ -259,7 +273,7 @@ class E2eRoomKeysHandler(object):
version(str): Optional; if None gives the most recent version
otherwise a historical one.
Raises:
StoreError: code 404 if the requested backup version doesn't exist
NotFoundError: if the requested backup version doesn't exist
Returns:
A deferred of a info dict that gives the info about the new version.
@ -271,7 +285,13 @@ class E2eRoomKeysHandler(object):
"""
with (yield self._upload_linearizer.queue(user_id)):
res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
try:
res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
defer.returnValue(res)
@defer.inlineCallbacks
@ -282,8 +302,60 @@ class E2eRoomKeysHandler(object):
user_id(str): the user whose current backup version we're deleting
version(str): the version id of the backup being deleted
Raises:
StoreError: code 404 if this backup version doesn't exist
NotFoundError: if this backup version doesn't exist
"""
with (yield self._upload_linearizer.queue(user_id)):
yield self.store.delete_e2e_room_keys_version(user_id, version)
try:
yield self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
@defer.inlineCallbacks
def update_version(self, user_id, version, version_info):
"""Update the info about a given version of the user's backup
Args:
user_id(str): the user whose current backup version we're updating
version(str): the backup version we're updating
version_info(dict): the new information about the backup
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
A deferred of an empty dict.
"""
if "version" not in version_info:
raise SynapseError(
400,
"Missing version in body",
Codes.MISSING_PARAM
)
if version_info["version"] != version:
raise SynapseError(
400,
"Version in body does not match",
Codes.INVALID_PARAM
)
with (yield self._upload_linearizer.queue(user_id)):
try:
old_info = yield self.store.get_e2e_room_keys_version_info(
user_id, version
)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
if old_info["algorithm"] != version_info["algorithm"]:
raise SynapseError(
400,
"Algorithm does not match",
Codes.INVALID_PARAM
)
yield self.store.update_e2e_room_keys_version(user_id, version, version_info)
defer.returnValue({})

View file

@ -34,6 +34,7 @@ from synapse.api.constants import (
EventTypes,
Membership,
RejectedReason,
RoomVersions,
)
from synapse.api.errors import (
AuthError,
@ -43,10 +44,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
from synapse.crypto.event_signing import (
add_hashes_and_signatures,
compute_event_signature,
)
from synapse.crypto.event_signing import compute_event_signature
from synapse.events.validator import EventValidator
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
@ -58,7 +56,6 @@ from synapse.types import UserID, get_domain_from_id
from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@ -105,7 +102,7 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.store = hs.get_datastore()
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@ -202,27 +199,22 @@ class FederationHandler(BaseHandler):
self.room_queues[room_id].append((pdu, origin))
return
# If we're no longer in the room just ditch the event entirely. This
# is probably an old server that has come back and thinks we're still
# in the room (or we've been rejoined to the room by a state reset).
# If we're not in the room just ditch the event entirely. This is
# probably an old server that has come back and thinks we're still in
# the room (or we've been rejoined to the room by a state reset).
#
# If we were never in the room then maybe our database got vaped and
# we should check if we *are* in fact in the room. If we are then we
# can magically rejoin the room.
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
is_in_room = yield self.auth.check_host_in_room(
room_id,
self.server_name
)
if not is_in_room:
was_in_room = yield self.store.was_host_joined(
pdu.room_id, self.server_name,
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
room_id, event_id, origin,
)
if was_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we've left the room",
room_id, event_id, origin,
)
defer.returnValue(None)
defer.returnValue(None)
state = None
auth_chain = []
@ -239,7 +231,7 @@ class FederationHandler(BaseHandler):
room_id, event_id, min_depth,
)
prevs = {e_id for e_id, _ in pdu.prev_events}
prevs = set(pdu.prev_event_ids())
seen = yield self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
@ -347,6 +339,8 @@ class FederationHandler(BaseHandler):
room_id, event_id, p,
)
room_version = yield self.store.get_room_version(room_id)
with logcontext.nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
@ -360,7 +354,7 @@ class FederationHandler(BaseHandler):
# we want the state *after* p; get_state_for_room returns the
# state *before* p.
remote_event = yield self.federation_client.get_pdu(
[origin], p, outlier=True,
[origin], p, room_version, outlier=True,
)
if remote_event is None:
@ -384,7 +378,6 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_store(
room_version, state_maps, event_map,
state_res_store=StateResolutionStore(self.store),
@ -557,86 +550,54 @@ class FederationHandler(BaseHandler):
room_id, event_id, event,
)
# FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work
# If state and auth_chain are None, then we don't need to do this check
# as we already know we have enough state in the DB to handle this
# event.
if state and auth_chain and not event.internal_metadata.is_outlier():
is_in_room = yield self.auth.check_host_in_room(
room_id,
self.server_name
)
else:
is_in_room = True
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = yield self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = e.auth_event_ids()
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
if not is_in_room:
logger.info(
"[%s %s] Got event for room we're not in",
room_id, event_id,
"[%s %s] persisting newly-received auth/state events %s",
room_id, event_id, [e["event"].event_id for e in event_infos]
)
yield self._handle_new_events(origin, event_infos)
try:
yield self._persist_auth_tree(
origin, auth_chain, state, event
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event_id,
)
else:
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = yield self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
logger.info(
"[%s %s] persisting newly-received auth/state events %s",
room_id, event_id, [e["event"].event_id for e in event_infos]
)
yield self._handle_new_events(origin, event_infos)
try:
context = yield self._handle_new_event(
origin,
event,
state=state,
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
try:
context = yield self._handle_new_event(
origin,
event,
state=state,
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
room = yield self.store.get_room(room_id)
@ -692,6 +653,8 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
room_version = yield self.store.get_room_version(room_id)
events = yield self.federation_client.backfill(
dest,
room_id,
@ -726,7 +689,7 @@ class FederationHandler(BaseHandler):
edges = [
ev.event_id
for ev in events
if set(e_id for e_id, _ in ev.prev_events) - event_ids
if set(ev.prev_event_ids()) - event_ids
]
logger.info(
@ -753,7 +716,7 @@ class FederationHandler(BaseHandler):
required_auth = set(
a_id
for event in events + list(state_events.values()) + list(auth_events.values())
for a_id, _ in event.auth_events
for a_id in event.auth_event_ids()
)
auth_events.update({
e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
@ -769,7 +732,7 @@ class FederationHandler(BaseHandler):
auth_events.update(ret_events)
required_auth.update(
a_id for event in ret_events.values() for a_id, _ in event.auth_events
a_id for event in ret_events.values() for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)
@ -785,6 +748,7 @@ class FederationHandler(BaseHandler):
self.federation_client.get_pdu,
[dest],
event_id,
room_version=room_version,
outlier=True,
timeout=10000,
)
@ -796,7 +760,7 @@ class FederationHandler(BaseHandler):
required_auth.update(
a_id
for event in results if event
for a_id, _ in event.auth_events
for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)
@ -816,7 +780,7 @@ class FederationHandler(BaseHandler):
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
for a_id in a.auth_event_ids()
if a_id in auth_events
}
})
@ -828,7 +792,7 @@ class FederationHandler(BaseHandler):
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
for a_id in event_map[e_id].auth_event_ids()
if a_id in auth_events
}
})
@ -1041,17 +1005,17 @@ class FederationHandler(BaseHandler):
Raises:
SynapseError if the event does not pass muster
"""
if len(ev.prev_events) > 20:
if len(ev.prev_event_ids()) > 20:
logger.warn("Rejecting event %s which has %i prev_events",
ev.event_id, len(ev.prev_events))
ev.event_id, len(ev.prev_event_ids()))
raise SynapseError(
http_client.BAD_REQUEST,
"Too many prev_events",
)
if len(ev.auth_events) > 10:
if len(ev.auth_event_ids()) > 10:
logger.warn("Rejecting event %s which has %i auth_events",
ev.event_id, len(ev.auth_events))
ev.event_id, len(ev.auth_event_ids()))
raise SynapseError(
http_client.BAD_REQUEST,
"Too many auth_events",
@ -1076,7 +1040,7 @@ class FederationHandler(BaseHandler):
def on_event_auth(self, event_id):
event = yield self.store.get_event(event_id)
auth = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
[auth_id for auth_id in event.auth_event_ids()],
include_given=True
)
defer.returnValue([e for e in auth])
@ -1097,7 +1061,7 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
origin, event = yield self._make_and_verify_event(
origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts,
room_id,
joinee,
@ -1120,7 +1084,6 @@ class FederationHandler(BaseHandler):
handled_events = set()
try:
event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/
# request first.
try:
@ -1128,7 +1091,9 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin)
except ValueError:
pass
ret = yield self.federation_client.send_join(target_hosts, event)
ret = yield self.federation_client.send_join(
target_hosts, event, event_format_version,
)
origin = ret["origin"]
state = ret["state"]
@ -1201,13 +1166,18 @@ class FederationHandler(BaseHandler):
"""
event_content = {"membership": Membership.JOIN}
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": event_content,
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
room_version = yield self.store.get_room_version(room_id)
builder = self.event_builder_factory.new(
room_version,
{
"type": EventTypes.Member,
"content": event_content,
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
}
)
try:
event, context = yield self.event_creation_handler.create_new_client_event(
@ -1219,7 +1189,9 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
yield self.auth.check_from_context(event, context, do_sig_check=False)
yield self.auth.check_from_context(
room_version, event, context, do_sig_check=False,
)
defer.returnValue(event)
@ -1324,11 +1296,11 @@ class FederationHandler(BaseHandler):
)
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
event.internal_metadata.out_of_band_membership = True
event.signatures.update(
compute_event_signature(
event,
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0]
)
@ -1341,7 +1313,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event(
origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
@ -1350,7 +1322,7 @@ class FederationHandler(BaseHandler):
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
event.internal_metadata.outlier = True
event = self._sign_event(event)
event.internal_metadata.out_of_band_membership = True
# Try the host that we succesfully called /make_leave/ on first for
# the /send_leave/ request.
@ -1373,7 +1345,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={}, params=None):
origin, pdu = yield self.federation_client.make_membership_event(
origin, event, format_ver = yield self.federation_client.make_membership_event(
target_hosts,
room_id,
user_id,
@ -1382,9 +1354,7 @@ class FederationHandler(BaseHandler):
params=params,
)
logger.debug("Got response to make_%s: %s", membership, pdu)
event = pdu
logger.debug("Got response to make_%s: %s", membership, event)
# We should assert some things.
# FIXME: Do this in a nicer way
@ -1392,28 +1362,7 @@ class FederationHandler(BaseHandler):
assert(event.user_id == user_id)
assert(event.state_key == user_id)
assert(event.room_id == room_id)
defer.returnValue((origin, event))
def _sign_event(self, event):
event.internal_metadata.outlier = False
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
return builder.build()
defer.returnValue((origin, event, format_ver))
@defer.inlineCallbacks
@log_function
@ -1422,13 +1371,17 @@ class FederationHandler(BaseHandler):
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
room_version = yield self.store.get_room_version(room_id)
builder = self.event_builder_factory.new(
room_version,
{
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
}
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
@ -1437,7 +1390,9 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
yield self.auth.check_from_context(event, context, do_sig_check=False)
yield self.auth.check_from_context(
room_version, event, context, do_sig_check=False,
)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@ -1696,9 +1651,16 @@ class FederationHandler(BaseHandler):
create_event = e
break
if create_event is None:
# If the state doesn't have a create event then the room is
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
room_version = create_event.content.get("room_version", RoomVersions.V1)
missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]):
for e_id, _ in e.auth_events:
for e_id in e.auth_event_ids():
if e_id not in event_map:
missing_auth_events.add(e_id)
@ -1706,6 +1668,7 @@ class FederationHandler(BaseHandler):
m_ev = yield self.federation_client.get_pdu(
[origin],
e_id,
room_version=room_version,
outlier=True,
timeout=10000,
)
@ -1717,14 +1680,14 @@ class FederationHandler(BaseHandler):
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
for e_id in e.auth_event_ids()
if e_id in event_map
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
try:
self.auth.check(e, auth_events=auth_for_e)
self.auth.check(room_version, e, auth_events=auth_for_e)
except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
@ -1785,10 +1748,10 @@ class FederationHandler(BaseHandler):
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1 and event.depth < 5:
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = yield self.store.get_event(
event.prev_events[0][0],
event.prev_event_ids()[0],
allow_none=True,
)
if c and c.type == EventTypes.Create:
@ -1835,7 +1798,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
[auth_id for auth_id in event.auth_event_ids()],
include_given=True
)
@ -1891,7 +1854,7 @@ class FederationHandler(BaseHandler):
"""
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
event_auth_events = set(event.auth_event_ids())
if event.is_state():
event_key = (event.type, event.state_key)
@ -1935,7 +1898,7 @@ class FederationHandler(BaseHandler):
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth_ids = e.auth_event_ids()
auth = {
(e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
@ -1956,7 +1919,7 @@ class FederationHandler(BaseHandler):
pass
have_events = yield self.store.get_seen_events_with_rejections(
[e_id for e_id, _ in event.auth_events]
event.auth_event_ids()
)
seen_events = set(have_events.keys())
except Exception:
@ -1968,6 +1931,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
room_version = yield self.store.get_room_version(event.room_id)
if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
@ -1992,8 +1957,6 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d
})
room_version = yield self.store.get_room_version(event.room_id)
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
@ -2058,7 +2021,7 @@ class FederationHandler(BaseHandler):
continue
try:
auth_ids = [e_id for e_id, _ in ev.auth_events]
auth_ids = ev.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
@ -2093,7 +2056,7 @@ class FederationHandler(BaseHandler):
)
try:
self.auth.check(event, auth_events=auth_events)
self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@ -2250,7 +2213,7 @@ class FederationHandler(BaseHandler):
missing_remote_ids = [e.event_id for e in missing_remotes]
base_remote_rejected = list(missing_remotes)
for e in missing_remotes:
for e_id, _ in e.auth_events:
for e_id in e.auth_event_ids():
if e_id in missing_remote_ids:
try:
base_remote_rejected.remove(e)
@ -2316,18 +2279,26 @@ class FederationHandler(BaseHandler):
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
room_version = yield self.store.get_room_version(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
event, context = yield self.add_display_name_to_third_party_invite(
event_dict, event, context
room_version, event_dict, event, context
)
EventValidator().validate_new(event)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
yield self.auth.check_from_context(event, context)
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
@ -2354,23 +2325,31 @@ class FederationHandler(BaseHandler):
Returns:
Deferred: resolves (to None)
"""
builder = self.event_builder_factory.new(event_dict)
room_version = yield self.store.get_room_version(room_id)
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
builder = self.event_builder_factory.new(room_version, event_dict)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
event, context = yield self.add_display_name_to_third_party_invite(
event_dict, event, context
room_version, event_dict, event, context
)
try:
self.auth.check_from_context(event, context)
self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
# XXX we send the invite here, but send_membership_event also sends it,
# so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event)
@ -2381,7 +2360,8 @@ class FederationHandler(BaseHandler):
yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context):
def add_display_name_to_third_party_invite(self, room_version, event_dict,
event, context):
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
@ -2405,11 +2385,12 @@ class FederationHandler(BaseHandler):
# auth checks. If we need the invite and don't have it then the
# auth check code will explode appropriately.
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
EventValidator().validate_new(event)
defer.returnValue((event, context))
@defer.inlineCallbacks

View file

@ -20,7 +20,7 @@ from six import iteritems
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
@ -46,13 +46,19 @@ def _create_rerouter(func_name):
# when the remote end responds with things like 403 Not
# In Group, we can communicate that to the client instead
# of a 500.
def h(failure):
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
if e.code == 403:
raise e.to_synapse_error()
return failure
d.addErrback(h)
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
raise SynapseError(502, "Failed to contact group server")
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
return f

View file

@ -167,18 +167,21 @@ class IdentityHandler(BaseHandler):
"mxid": mxid,
"threepid": threepid,
}
headers = {}
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
self.federation_http_client.sign_request(
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method='POST',
url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
headers_dict=headers,
content=content,
destination_is=id_server,
)
headers = {
b"Authorization": auth_headers,
}
try:
yield self.http_client.post_json_get_json(
url,

View file

@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.api.errors import (
AuthError,
Codes,
@ -31,7 +31,6 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@ -278,9 +277,17 @@ class EventCreationHandler(object):
"""
yield self.auth.check_auth_blocking(requester.user.to_string())
builder = self.event_builder_factory.new(event_dict)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
else:
try:
room_version = yield self.store.get_room_version(event_dict["room_id"])
except NotFoundError:
raise AuthError(403, "Unknown room")
self.validator.validate_new(builder)
builder = self.event_builder_factory.new(room_version, event_dict)
self.validator.validate_builder(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
@ -318,6 +325,8 @@ class EventCreationHandler(object):
prev_events_and_hashes=prev_events_and_hashes,
)
self.validator.validate_new(event)
defer.returnValue((event, context))
def _is_exempt_from_privacy_policy(self, builder, requester):
@ -427,6 +436,9 @@ class EventCreationHandler(object):
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
logger.info(
"Not bothering to persist duplicate state event %s", event.event_id,
)
if prev_state is not None:
defer.returnValue(prev_state)
@ -532,40 +544,19 @@ class EventCreationHandler(object):
prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id)
if prev_events_and_hashes:
depth = max([d for _, _, d in prev_events_and_hashes]) + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes
]
builder.prev_events = prev_events
builder.depth = depth
context = yield self.state.compute_event_context(builder)
event = yield builder.build(
prev_event_ids=[p for p, _ in prev_events],
)
context = yield self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
signing_key = self.hs.config.signing_key[0]
add_hashes_and_signatures(
builder, self.server_name, signing_key
)
event = builder.build()
self.validator.validate_new(event)
logger.debug(
"Created event %s",
@ -600,8 +591,13 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event
"""
if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
room_version = event.content.get("room_version", RoomVersions.V1)
else:
room_version = yield self.store.get_room_version(event.room_id)
try:
yield self.auth.check_from_context(event, context)
yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
@ -749,7 +745,8 @@ class EventCreationHandler(object):
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
if self.auth.check_redaction(event, auth_events=auth_events):
room_version = yield self.store.get_room_version(event.room_id)
if self.auth.check_redaction(room_version, event, auth_events=auth_events):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
@ -763,6 +760,9 @@ class EventCreationHandler(object):
"You don't have permission to redact events"
)
# We've already checked.
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:

View file

@ -235,6 +235,17 @@ class PaginationHandler(object):
"room_key", next_key
)
if events:
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
self.store,
user_id,
events,
is_peeking=(member_event_id is None),
)
if not events:
defer.returnValue({
"chunk": [],
@ -242,18 +253,8 @@ class PaginationHandler(object):
"end": next_token.to_string(),
})
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
self.store,
user_id,
events,
is_peeking=(member_event_id is None),
)
state = None
if event_filter and event_filter.lazy_load_members():
if event_filter and event_filter.lazy_load_members() and len(events) > 0:
# TODO: remove redundant members
# FIXME: we also care about invite targets etc.

View file

@ -16,8 +16,8 @@ import logging
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util import logcontext
from ._base import BaseHandler
@ -59,7 +59,9 @@ class ReceiptsHandler(BaseHandler):
if is_new:
# fire off a process in the background to send the receipt to
# remote servers
self._push_remotes([receipt])
run_as_background_process(
'push_receipts_to_remotes', self._push_remotes, receipt
)
@defer.inlineCallbacks
def _received_remote_receipt(self, origin, content):
@ -125,44 +127,42 @@ class ReceiptsHandler(BaseHandler):
defer.returnValue(True)
@logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
def _push_remotes(self, receipt):
"""Given a receipt, works out which remote servers should be
poked and pokes them.
"""
try:
# TODO: Some of this stuff should be coallesced.
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
# TODO: optimise this to move some of the work to the workers.
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
users = yield self.state.get_current_user_in_room(room_id)
remotedomains = set(get_domain_from_id(u) for u in users)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)
users = yield self.state.get_current_user_in_room(room_id)
remotedomains = set(get_domain_from_id(u) for u in users)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)
logger.debug("Sending receipt to: %r", remotedomains)
logger.debug("Sending receipt to: %r", remotedomains)
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
},
}
},
key=(room_id, receipt_type, user_id),
)
},
key=(room_id, receipt_type, user_id),
)
except Exception:
logger.exception("Error pushing receipts to remote servers")

View file

@ -19,6 +19,7 @@ import logging
from twisted.internet import defer
from synapse import types
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
Codes,
@ -26,7 +27,14 @@ from synapse.api.errors import (
RegistrationError,
SynapseError,
)
from synapse.config.server import is_threepid_reserved
from synapse.http.client import CaptchaServerHttpClient
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet,
)
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.threepids import check_3pid_allowed
@ -50,8 +58,8 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self.room_creation_handler = self.hs.get_room_creation_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self.identity_handler = self.hs.get_handlers().identity_handler
self._next_generated_user_id = None
@ -62,6 +70,18 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = (
RegisterDeviceReplicationServlet.make_client(hs)
)
self._post_registration_client = (
ReplicationPostRegisterActionsServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
self.pusher_pool = hs.get_pusherpool()
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
@ -127,6 +147,8 @@ class RegistrationHandler(BaseHandler):
make_guest=False,
admin=False,
threepid=None,
user_type=None,
default_display_name=None,
):
"""Registers a new client on the server.
@ -141,6 +163,10 @@ class RegistrationHandler(BaseHandler):
since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token
after registration.
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
Returns:
A tuple of (user_id, access_token).
Raises:
@ -150,7 +176,7 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)
password_hash = yield self._auth_handler.hash(password)
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@ -170,20 +196,25 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
if was_guest:
# If the user was a guest then they already have a profile
default_display_name = None
elif default_display_name is None:
default_display_name = localpart
token = None
if generate_token:
token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
create_profile_with_localpart=(
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
create_profile_with_displayname=default_display_name,
admin=admin,
user_type=user_type,
)
if self.hs.config.user_directory_search_all_users:
@ -204,13 +235,15 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token:
token = self.macaroon_gen.generate_access_token(user_id)
if default_display_name is None:
default_display_name = localpart
try:
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
make_guest=make_guest,
create_profile_with_localpart=user.localpart,
create_profile_with_displayname=default_display_name,
)
except SynapseError:
# if user id is taken, just generate another
@ -218,16 +251,34 @@ class RegistrationHandler(BaseHandler):
user_id = None
token = None
attempts += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
Args:
user_id(str): The user to join
"""
# auto-join the user to any rooms we're supposed to dump them into
fake_requester = create_requester(user_id)
# try to create the room if we're the first user on the server
# try to create the room if we're the first real user on the server. Note
# that an auto-generated support user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
if self.hs.config.autocreate_auto_join_rooms:
is_support = yield self.store.is_support_user(user_id)
# There is an edge case where the first user is the support user, then
# the room is never created, though this seems unlikely and
# recoverable from given the support user being involved in the first
# place.
if self.hs.config.autocreate_auto_join_rooms and not is_support:
count = yield self.store.count_all_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
try:
if should_auto_create_rooms:
@ -241,7 +292,10 @@ class RegistrationHandler(BaseHandler):
else:
# create room expects the localpart of the room alias
room_alias_localpart = room_alias.localpart
yield self.room_creation_handler.create_room(
# getting the RoomCreationHandler during init gives a dependency
# loop
yield self.hs.get_room_creation_handler().create_room(
fake_requester,
config={
"preset": "public_chat",
@ -254,10 +308,15 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
# rather than there being consistent matrix-wide ones, so we don't.
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
user_id (str): The user to join
"""
yield self._auto_join_rooms(user_id)
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
@ -278,11 +337,11 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
password_hash="",
appservice_id=service_id,
create_profile_with_localpart=user.localpart,
create_profile_with_displayname=user.localpart,
)
defer.returnValue(user_id)
@ -309,35 +368,6 @@ class RegistrationHandler(BaseHandler):
else:
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
def register_saml2(self, localpart):
"""
Registers email_id as SAML2 Based Auth.
"""
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
"User ID can only contain characters a-z, 0-9, or '=_-./'",
)
yield self.auth.check_auth_blocking()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
token = self.macaroon_gen.generate_access_token(user_id)
try:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None,
create_profile_with_localpart=user.localpart,
)
except Exception as e:
yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors
logger.exception(e)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def register_email(self, threepidCreds):
"""
@ -350,8 +380,7 @@ class RegistrationHandler(BaseHandler):
logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
threepid = yield self.identity_handler.threepid_from_creds(c)
except Exception:
logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid")
@ -375,9 +404,8 @@ class RegistrationHandler(BaseHandler):
# Now we have a matrix ID, bind it to the threepids we were given
for c in threepidCreds:
identity_handler = self.hs.get_handlers().identity_handler
# XXX: This should be a deferred list, shouldn't it?
yield identity_handler.bind_threepid(c, user_id)
yield self.identity_handler.bind_threepid(c, user_id)
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
# don't allow people to register the server notices mxid
@ -485,11 +513,11 @@ class RegistrationHandler(BaseHandler):
token = self.macaroon_gen.generate_access_token(user_id)
if need_register:
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
create_profile_with_localpart=user.localpart,
create_profile_with_displayname=user.localpart,
)
else:
yield self._auth_handler.delete_access_tokens_for_user(user_id)
@ -503,9 +531,6 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_auth_handler()
@defer.inlineCallbacks
def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
@ -564,3 +589,275 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
def _register_with_store(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False,
user_type=None):
"""Register user in the datastore.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
Returns:
Deferred
"""
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)
else:
return self.store.register(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)
@defer.inlineCallbacks
def register_device(self, user_id, device_id, initial_display_name,
is_guest=False):
"""Register a device for a user and generate an access token.
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
a new one.
initial_display_name (str|None): An optional display name for the
device.
is_guest (bool): Whether this is a guest account
Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
"""
if self.hs.config.worker_app:
r = yield self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
else:
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
defer.returnValue((device_id, access_token))
@defer.inlineCallbacks
def post_registration_actions(self, user_id, auth_result, access_token,
bind_email, bind_msisdn):
"""A user has completed registration
Args:
user_id (str): The user ID that consented
auth_result (dict): The authenticated credentials of the newly
registered user.
access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled.
bind_email (bool): Whether to bind the email with the identity
server
bind_msisdn (bool): Whether to bind the msisdn with the identity
server
"""
if self.hs.config.worker_app:
yield self._post_registration_client(
user_id=user_id,
auth_result=auth_result,
access_token=access_token,
bind_email=bind_email,
bind_msisdn=bind_msisdn,
)
return
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
yield self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(
user_id, threepid, access_token,
bind_email,
)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(
user_id, threepid, bind_msisdn,
)
if auth_result and LoginType.TERMS in auth_result:
yield self._on_user_consented(
user_id, self.hs.config.user_consent_version,
)
@defer.inlineCallbacks
def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration
Args:
user_id (str): The user ID that consented
consent_version (str): version of the policy the user has
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
yield self.store.user_set_consent_version(
user_id, consent_version,
)
yield self.post_consent_actions(user_id)
@defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token, bind_email):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
HS config
Also optionally binds emails to the given user_id on the identity server
Must be called on master.
Args:
user_id (str): id of user
threepid (object): m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged
in.
bind_email (bool): true if the client requested the email to be
bound at the identity server
Returns:
defer.Deferred:
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
return
yield self._auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
if (self.hs.config.email_enable_notifs and
self.hs.config.email_notif_for_new_users
and token):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(
token
)
token_id = user_tuple["token_id"]
yield self.pusher_pool.add_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
lang=None, # We don't know a user's language here
data={},
)
if bind_email:
logger.info("bind_email specified: binding")
logger.debug("Binding emails %s to %s" % (
threepid, user_id
))
yield self.identity_handler.bind_threepid(
threepid['threepid_creds'], user_id
)
else:
logger.info("bind_email not specified: not binding email")
@defer.inlineCallbacks
def _register_msisdn_threepid(self, user_id, threepid, bind_msisdn):
"""Add a phone number as a 3pid identifier
Also optionally binds msisdn to the given user_id on the identity server
Must be called on master.
Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
token (str): access_token for the user
bind_email (bool): true if the client requested the email to be
bound at the identity server
Returns:
defer.Deferred:
"""
try:
assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
except SynapseError as ex:
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue(None)
raise
yield self._auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if bind_msisdn:
logger.info("bind_msisdn specified: binding")
logger.debug("Binding msisdn %s to %s", threepid, user_id)
yield self.identity_handler.bind_threepid(
threepid['threepid_creds'], user_id
)
else:
logger.info("bind_msisdn not specified: not binding msisdn")

View file

@ -21,7 +21,7 @@ import math
import string
from collections import OrderedDict
from six import string_types
from six import iteritems, string_types
from twisted.internet import defer
@ -32,10 +32,11 @@ from synapse.api.constants import (
JoinRules,
RoomCreationPreset,
)
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@ -73,6 +74,372 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@defer.inlineCallbacks
def upgrade_room(self, requester, old_room_id, new_version):
"""Replace a room with a new room with a different version
Args:
requester (synapse.types.Requester): the user requesting the upgrade
old_room_id (unicode): the id of the room to be replaced
new_version (unicode): the new room version to use
Returns:
Deferred[unicode]: the new room id
"""
yield self.ratelimit(requester)
user_id = requester.user.to_string()
with (yield self._upgrade_linearizer.queue(old_room_id)):
# start by allocating a new room id
r = yield self.store.get_room(old_room_id)
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = yield self._generate_room_id(
creator_id=user_id, is_public=r["is_public"],
)
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
# we create and auth the tombstone event before properly creating the new
# room, to check our user has perms in the old room.
tombstone_event, tombstone_context = (
yield self.event_creation_handler.create_event(
requester, {
"type": EventTypes.Tombstone,
"state_key": "",
"room_id": old_room_id,
"sender": user_id,
"content": {
"body": "This room has been replaced",
"replacement_room": new_room_id,
}
},
token_id=requester.access_token_id,
)
)
old_room_version = yield self.store.get_room_version(old_room_id)
yield self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context,
)
yield self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
new_room_version=new_version,
tombstone_event_id=tombstone_event.event_id,
)
# now send the tombstone
yield self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context,
)
old_room_state = yield tombstone_context.get_current_state_ids(self.store)
# update any aliases
yield self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state,
)
# and finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state,
)
defer.returnValue(new_room_id)
@defer.inlineCallbacks
def _update_upgraded_room_pls(
self, requester, old_room_id, new_room_id, old_room_state,
):
"""Send updated power levels in both rooms after an upgrade
Args:
requester (synapse.types.Requester): the user requesting the upgrade
old_room_id (unicode): the id of the room to be replaced
new_room_id (unicode): the id of the replacement room
old_room_state (dict[tuple[str, str], str]): the state map for the old room
Returns:
Deferred
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
if old_room_pl_event_id is None:
logger.warning(
"Not supported: upgrading a room with no PL event. Not setting PLs "
"in old room.",
)
return
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally
# 50, but if the default PL in a room is 50 or more, then we set the
# required PL above that.
pl_content = dict(old_room_pl_state.content)
users_default = int(pl_content.get("users_default", 0))
restricted_level = max(users_default + 1, 50)
updated = False
for v in ("invite", "events_default"):
current = int(pl_content.get(v, 0))
if current < restricted_level:
logger.info(
"Setting level for %s in %s to %i (was %i)",
v, old_room_id, restricted_level, current,
)
pl_content[v] = restricted_level
updated = True
else:
logger.info(
"Not setting level for %s (already %i)",
v, current,
)
if updated:
try:
yield self.event_creation_handler.create_and_send_nonmember_event(
requester, {
"type": EventTypes.PowerLevels,
"state_key": '',
"room_id": old_room_id,
"sender": requester.user.to_string(),
"content": pl_content,
}, ratelimit=False,
)
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
logger.info("Setting correct PLs in new room")
yield self.event_creation_handler.create_and_send_nonmember_event(
requester, {
"type": EventTypes.PowerLevels,
"state_key": '',
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": old_room_pl_state.content,
}, ratelimit=False,
)
@defer.inlineCallbacks
def clone_existing_room(
self, requester, old_room_id, new_room_id, new_room_version,
tombstone_event_id,
):
"""Populate a new room based on an old room
Args:
requester (synapse.types.Requester): the user requesting the upgrade
old_room_id (unicode): the id of the room to be replaced
new_room_id (unicode): the id to give the new room (should already have been
created with _gemerate_room_id())
new_room_version (unicode): the new room version to use
tombstone_event_id (unicode|str): the ID of the tombstone event in the old
room.
Returns:
Deferred[None]
"""
user_id = requester.user.to_string()
if not self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
"room_version": new_room_version,
"predecessor": {
"room_id": old_room_id,
"event_id": tombstone_event_id,
}
}
# Check if old room was non-federatable
# Get old room's create event
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True):
# If so, mark the new room as non-federatable as well
creation_content["m.federate"] = False
initial_state = dict()
# Replicate relevant room events
types_to_copy = (
(EventTypes.JoinRules, ""),
(EventTypes.Name, ""),
(EventTypes.Topic, ""),
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.GuestAccess, ""),
(EventTypes.RoomAvatar, ""),
(EventTypes.Encryption, ""),
(EventTypes.ServerACL, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy),
)
# map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id)
if old_event:
initial_state[k] = old_event.content
yield self._send_events_for_new_room(
requester,
new_room_id,
# we expect to override all the presets with initial_state, so this is
# somewhat arbitrary.
preset_config=RoomCreationPreset.PRIVATE_CHAT,
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
)
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)]),
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
old_room_member_state_ids.values(),
)
for k, old_event in iteritems(old_room_member_state_events):
# Only transfer ban events
if ("membership" in old_event.content and
old_event.content["membership"] == "ban"):
yield self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event['state_key']),
new_room_id,
"ban",
ratelimit=False,
content=old_event.content,
)
# XXX invites/joins
# XXX 3pid invites
@defer.inlineCallbacks
def _move_aliases_to_new_room(
self, requester, old_room_id, new_room_id, old_room_state,
):
directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
canonical_alias = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
if canonical_alias_event:
canonical_alias = canonical_alias_event.content.get("alias", "")
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
#
# Note that we'll only be able to remove aliases that (a) aren't owned by an AS,
# and (b) unless the user is a server admin, which the user created.
#
# This is probably correct - given we don't allow such aliases to be deleted
# normally, it would be odd to allow it in the case of doing a room upgrade -
# but it makes the upgrade less effective, and you have to wonder why a room
# admin can't remove aliases that point to that room anyway.
# (cf https://github.com/matrix-org/synapse/issues/2360)
#
removed_aliases = []
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
yield directory_handler.delete_association(
requester, alias, send_event=False,
)
removed_aliases.append(alias_str)
except SynapseError as e:
logger.warning(
"Unable to remove alias %s from old room: %s",
alias, e,
)
# if we didn't find any aliases, or couldn't remove anyway, we can skip the rest
# of this.
if not removed_aliases:
return
try:
# this can fail if, for some reason, our user doesn't have perms to send
# m.room.aliases events in the old room (note that we've already checked that
# they have perms to send a tombstone event, so that's not terribly likely).
#
# If that happens, it's regrettable, but we should carry on: it's the same
# as when you remove an alias from the directory normally - it just means that
# the aliases event gets out of sync with the directory
# (cf https://github.com/vector-im/riot-web/issues/2369)
yield directory_handler.send_room_alias_update_event(
requester, old_room_id,
)
except AuthError as e:
logger.warning(
"Failed to send updated alias event on old room: %s", e,
)
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
yield directory_handler.create_association(
requester, RoomAlias.from_string(alias),
new_room_id, servers=(self.hs.hostname, ),
send_event=False,
)
logger.info("Moved alias %s to new room", alias)
except SynapseError as e:
# I'm not really expecting this to happen, but it could if the spam
# checking module decides it shouldn't, or similar.
logger.error(
"Error adding alias %s to new room: %s",
alias, e,
)
try:
if canonical_alias and (canonical_alias in removed_aliases):
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": {"alias": canonical_alias, },
},
ratelimit=False
)
yield directory_handler.send_room_alias_update_event(
requester, new_room_id,
)
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
logger.error(
"Unable to send updated alias events in new room: %s", e,
)
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True,
@ -104,7 +471,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
self.auth.check_auth_blocking(user_id)
yield self.auth.check_auth_blocking(user_id)
if not self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
@ -165,28 +532,7 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None)
is_public = visibility == "public"
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
room_id = None
while attempts < 5:
try:
random_string = stringutils.random_string(18)
gen_room_id = RoomID(
random_string,
self.hs.hostname,
)
yield self.store.store_room(
room_id=gen_room_id.to_string(),
room_creator_user_id=user_id,
is_public=is_public
)
room_id = gen_room_id.to_string()
break
except StoreError:
attempts += 1
if not room_id:
raise StoreError(500, "Couldn't generate a room ID.")
room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public)
if room_alias:
directory_handler = self.hs.get_handlers().directory_handler
@ -216,18 +562,15 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version
room_member_handler = self.hs.get_room_member_handler()
yield self._send_events_for_new_room(
requester,
room_id,
room_member_handler,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
power_level_content_override=config.get("power_level_content_override", {}),
power_level_content_override=config.get("power_level_content_override"),
creator_join_profile=creator_join_profile,
)
@ -263,7 +606,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
yield room_member_handler.update_membership(
yield self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@ -301,14 +644,13 @@ class RoomCreationHandler(BaseHandler):
self,
creator, # A Requester object.
room_id,
room_member_handler,
preset_config,
invite_list,
initial_state,
creation_content,
room_alias,
power_level_content_override,
creator_join_profile,
room_alias=None,
power_level_content_override=None,
creator_join_profile=None,
):
def create(etype, content, **kwargs):
e = {
@ -324,6 +666,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
logger.info("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
@ -346,7 +689,8 @@ class RoomCreationHandler(BaseHandler):
content=creation_content,
)
yield room_member_handler.update_membership(
logger.info("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
@ -388,7 +732,8 @@ class RoomCreationHandler(BaseHandler):
for invitee in invite_list:
power_level_content["users"][invitee] = 100
power_level_content.update(power_level_content_override)
if power_level_content_override:
power_level_content.update(power_level_content_override)
yield send(
etype=EventTypes.PowerLevels,
@ -427,6 +772,30 @@ class RoomCreationHandler(BaseHandler):
content=content,
)
@defer.inlineCallbacks
def _generate_room_id(self, creator_id, is_public):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
while attempts < 5:
try:
random_string = stringutils.random_string(18)
gen_room_id = RoomID(
random_string,
self.hs.hostname,
).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode('utf-8')
yield self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
)
defer.returnValue(gen_room_id)
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a room ID.")
class RoomContextHandler(object):
def __init__(self, hs):

View file

@ -75,8 +75,14 @@ class RoomListHandler(BaseHandler):
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
logger.info("Bypassing cache as search request.")
# XXX: Quick hack to stop room directory queries taking too long.
# Timeout request after 60s. Probably want a more fundamental
# solution at some point
timeout = self.clock.time() + 60
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
limit, since_token, search_filter,
network_tuple=network_tuple, timeout=timeout,
)
key = (limit, since_token, network_tuple)
@ -91,7 +97,7 @@ class RoomListHandler(BaseHandler):
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
network_tuple=EMPTY_THIRD_PARTY_ID,
from_federation=False,):
from_federation=False, timeout=None,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
@ -206,6 +212,9 @@ class RoomListHandler(BaseHandler):
chunk = []
for i in range(0, len(rooms_to_scan), step):
if timeout and self.clock.time() > timeout:
raise Exception("Timed out searching room directory")
batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(

View file

@ -61,9 +61,9 @@ class RoomMemberHandler(object):
self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler
self.registration_handler = hs.get_handlers().registration_handler
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
@ -161,6 +161,8 @@ class RoomMemberHandler(object):
ratelimit=True,
content=None,
):
user_id = target.to_string()
if content is None:
content = {}
@ -168,14 +170,14 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
event, context = yield self.event_creation_hander.create_event(
event, context = yield self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
"state_key": user_id,
# For backwards compatibility:
"membership": membership,
@ -186,14 +188,14 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_hander.deduplicate_state_event(
duplicate = yield self.event_creation_handler.deduplicate_state_event(
event, context,
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
yield self.event_creation_hander.handle_new_client_event(
yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@ -204,12 +206,12 @@ class RoomMemberHandler(object):
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, target.to_string()),
(EventTypes.Member, user_id),
None
)
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
@ -218,6 +220,18 @@ class RoomMemberHandler(object):
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target, room_id)
# Copy over direct message status and room tags if this is a join
# on an upgraded room
# Check if this is an upgraded room
predecessor = yield self.store.get_room_predecessor(room_id)
if predecessor:
# It is an upgraded room. Copy over old tags
self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], room_id, user_id,
)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
@ -226,6 +240,55 @@ class RoomMemberHandler(object):
defer.returnValue(event)
@defer.inlineCallbacks
def copy_room_tags_and_direct_to_room(
self,
old_room_id,
new_room_id,
user_id,
):
"""Copies the tags and direct room state from one room to another.
Args:
old_room_id (str)
new_room_id (str)
user_id (str)
Returns:
Deferred[None]
"""
# Retrieve user account data for predecessor room
user_account_data, _ = yield self.store.get_account_data_for_user(
user_id,
)
# Copy direct message state if applicable
direct_rooms = user_account_data.get("m.direct", {})
# Check which key this room is under
if isinstance(direct_rooms, dict):
for key, room_id_list in direct_rooms.items():
if old_room_id in room_id_list and new_room_id not in room_id_list:
# Add new room_id to this key
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
yield self.store.add_account_data_for_user(
user_id, "m.direct", direct_rooms,
)
break
# Copy room tags if applicable
room_tags = yield self.store.get_tags_for_room(
user_id, old_room_id,
)
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(
user_id, new_room_id, tag, tag_content
)
@defer.inlineCallbacks
def update_membership(
self,
@ -493,7 +556,7 @@ class RoomMemberHandler(object):
else:
requester = synapse.types.create_requester(target_user)
prev_event = yield self.event_creation_hander.deduplicate_state_event(
prev_event = yield self.event_creation_handler.deduplicate_state_event(
event, context,
)
if prev_event is not None:
@ -513,7 +576,7 @@ class RoomMemberHandler(object):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_hander.handle_new_client_event(
yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@ -527,7 +590,7 @@ class RoomMemberHandler(object):
)
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
@ -755,7 +818,7 @@ class RoomMemberHandler(object):
)
)
yield self.event_creation_hander.create_and_send_nonmember_event(
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@ -877,7 +940,8 @@ class RoomMemberHandler(object):
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
defer.returnValue(self.hs.is_mine_id(create_event_id))
# We can only get here if we're in the process of creating the room
defer.returnValue(True)
for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):

View file

@ -37,6 +37,41 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
`predecessor` key. If it exists, we add the room ID to our return
list and then check that room for a m.room.create event and so on
until we can no longer find any more previous rooms.
The full list of all found rooms in then returned.
Args:
room_id (str): id of the room to search through.
Returns:
Deferred[iterable[unicode]]: predecessor room ids
"""
historical_room_ids = []
while True:
predecessor = yield self.store.get_room_predecessor(room_id)
# If no predecessor, assume we've hit a dead end
if not predecessor:
break
# Add predecessor's room ID
historical_room_ids.append(predecessor["room_id"])
# Scan through the old room for further predecessors
room_id = predecessor["room_id"]
defer.returnValue(historical_room_ids)
@defer.inlineCallbacks
def search(self, user, content, batch=None):
"""Performs a full text search for a user.
@ -50,6 +85,9 @@ class SearchHandler(BaseHandler):
dict to be returned to the client with results of search
"""
if not self.hs.config.enable_search:
raise SynapseError(400, "Search is disabled on this homeserver")
batch_group = None
batch_group_key = None
batch_token = None
@ -134,6 +172,18 @@ class SearchHandler(BaseHandler):
)
room_ids = set(r.room_id for r in rooms)
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = yield self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids
# Prevent any historical events from being filtered
search_filter = search_filter.with_room_ids(historical_room_ids)
room_ids = search_filter.filter_rooms(room_ids)
if batch_group == "room_id":

View file

@ -895,14 +895,17 @@ class SyncHandler(object):
Returns:
Deferred(SyncResult)
"""
logger.info("Calculating sync response for %r", sync_config.user)
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token()
logger.info(
"Calculating sync response for %r between %s and %s",
sync_config.user, since_token, now_token,
)
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
@ -1390,6 +1393,12 @@ class SyncHandler(object):
room_entries = []
invited = []
for room_id, events in iteritems(mem_change_events_by_room_id):
logger.info(
"Membership changes in %s: [%s]",
room_id,
", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
)
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
@ -1473,10 +1482,22 @@ class SyncHandler(object):
if since_token and since_token.is_after(leave_token):
continue
# If this is an out of band message, like a remote invite
# rejection, we include it in the recents batch. Otherwise, we
# let _load_filtered_recents handle fetching the correct
# batches.
#
# This is all screaming out for a refactor, as the logic here is
# subtle and the moving parts numerous.
if leave_event.internal_metadata.is_out_of_band_membership():
batch_events = [leave_event]
else:
batch_events = None
room_entries.append(RoomSyncResultBuilder(
room_id=room_id,
rtype="archived",
events=None,
events=batch_events,
newly_joined=room_id in newly_joined_rooms,
full_state=False,
since_token=since_token,
@ -1668,13 +1689,17 @@ class SyncHandler(object):
"content": content,
})
account_data = sync_config.filter_collection.filter_room_account_data(
account_data_events = sync_config.filter_collection.filter_room_account_data(
account_data_events
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
if not (always_include or batch or account_data or ephemeral or full_state):
if not (always_include
or batch
or account_data_events
or ephemeral
or full_state):
return
state = yield self.compute_state_delta(
@ -1745,7 +1770,7 @@ class SyncHandler(object):
room_id=room_id,
timeline=batch,
state=state,
account_data=account_data,
account_data=account_data_events,
)
if room_sync or always_include:
sync_result_builder.archived.append(room_sync)

View file

@ -63,11 +63,8 @@ class TypingHandler(object):
self._member_typing_until = {} # clock time we expect to stop
self._member_last_federation_poke = {}
# map room IDs to serial numbers
self._room_serials = {}
self._latest_room_serial = 0
# map room IDs to sets of users currently typing
self._room_typing = {}
self._reset()
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
@ -79,6 +76,15 @@ class TypingHandler(object):
5000,
)
def _reset(self):
"""
Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
def _handle_timeouts(self):
logger.info("Checking for typing timeouts")

View file

@ -19,6 +19,7 @@ from six import iteritems
from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
@ -125,9 +126,12 @@ class UserDirectoryHandler(object):
"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
yield self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url, None,
)
is_support = yield self.store.is_support_user(user_id)
# Support users are for diagnostics and should not appear in the user directory.
if not is_support:
yield self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url, None
)
@defer.inlineCallbacks
def handle_user_deactivated(self, user_id):
@ -160,6 +164,12 @@ class UserDirectoryHandler(object):
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("user_dir").set(
self.pos
)
yield self.store.update_user_directory_stream_pos(self.pos)
@defer.inlineCallbacks
@ -182,21 +192,25 @@ class UserDirectoryHandler(object):
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
yield self._handle_initial_room(room_id)
num_processed_rooms += 1
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
logger.info("Processed all rooms.")
if self.search_all_users:
num_processed_users = 0
user_ids = yield self.store.get_all_local_users()
logger.info("Doing initial update of user directory. %d users", len(user_ids))
logger.info(
"Doing initial update of user directory. %d users", len(user_ids)
)
for user_id in user_ids:
# We add profiles for all users even if they don't match the
# include pattern, just in case we want to change it in future
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
logger.info(
"Handling user %d/%d", num_processed_users + 1, len(user_ids)
)
yield self._handle_local_user(user_id)
num_processed_users += 1
yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.0)
logger.info("Processed all users")
@ -215,24 +229,24 @@ class UserDirectoryHandler(object):
if not is_in_room:
return
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id)
is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
room_id
)
users_with_profile = yield self.state.get_current_user_in_room(room_id)
user_ids = set(users_with_profile)
unhandled_users = user_ids - self.initially_handled_users
yield self.store.add_profiles_to_user_dir(
room_id, {
user_id: users_with_profile[user_id] for user_id in unhandled_users
}
room_id,
{user_id: users_with_profile[user_id] for user_id in unhandled_users},
)
self.initially_handled_users |= unhandled_users
if is_public:
yield self.store.add_users_to_public_room(
room_id,
user_ids=user_ids - self.initially_handled_users_in_public
room_id, user_ids=user_ids - self.initially_handled_users_in_public
)
self.initially_handled_users_in_public |= user_ids
@ -244,7 +258,7 @@ class UserDirectoryHandler(object):
count = 0
for user_id in user_ids:
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
if not self.is_mine_id(user_id):
count += 1
@ -259,7 +273,7 @@ class UserDirectoryHandler(object):
continue
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
count += 1
user_set = (user_id, other_user_id)
@ -281,25 +295,23 @@ class UserDirectoryHandler(object):
if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
room_id, not is_public, to_insert
)
to_insert.clear()
if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
room_id, not is_public, to_update
)
to_update.clear()
if to_insert:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
to_insert.clear()
if to_update:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
room_id, not is_public, to_update
)
to_update.clear()
@ -320,50 +332,55 @@ class UserDirectoryHandler(object):
# may have become public or not and add/remove the users in said room
if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
yield self._handle_room_publicity_change(
room_id, prev_event_id, event_id, typ,
room_id, prev_event_id, event_id, typ
)
elif typ == EventTypes.Member:
change = yield self._get_key_change(
prev_event_id, event_id,
prev_event_id,
event_id,
key_name="membership",
public_value=Membership.JOIN,
)
if change is None:
# Handle any profile changes
yield self._handle_profile_change(
state_key, room_id, prev_event_id, event_id,
)
continue
if not change:
if change is False:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = yield self.store.is_host_joined(
room_id, self.server_name,
room_id, self.server_name
)
if not is_in_room:
logger.info("Server left room: %r", room_id)
# Fetch all the users that we marked as being in user
# directory due to being in the room and then check if
# need to remove those users or not
user_ids = yield self.store.get_users_in_dir_due_to_room(room_id)
user_ids = yield self.store.get_users_in_dir_due_to_room(
room_id
)
for user_id in user_ids:
yield self._handle_remove_user(room_id, user_id)
return
else:
logger.debug("Server is still in room: %r", room_id)
if change: # The user joined
event = yield self.store.get_event(event_id, allow_none=True)
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
)
is_support = yield self.store.is_support_user(state_key)
if not is_support:
if change is None:
# Handle any profile changes
yield self._handle_profile_change(
state_key, room_id, prev_event_id, event_id
)
continue
yield self._handle_new_user(room_id, state_key, profile)
else: # The user left
yield self._handle_remove_user(room_id, state_key)
if change: # The user joined
event = yield self.store.get_event(event_id, allow_none=True)
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
)
yield self._handle_new_user(room_id, state_key, profile)
else: # The user left
yield self._handle_remove_user(room_id, state_key)
else:
logger.debug("Ignoring irrelevant type: %r", typ)
@ -382,13 +399,15 @@ class UserDirectoryHandler(object):
if typ == EventTypes.RoomHistoryVisibility:
change = yield self._get_key_change(
prev_event_id, event_id,
prev_event_id,
event_id,
key_name="history_visibility",
public_value="world_readable",
)
elif typ == EventTypes.JoinRules:
change = yield self._get_key_change(
prev_event_id, event_id,
prev_event_id,
event_id,
key_name="join_rule",
public_value=JoinRules.PUBLIC,
)
@ -513,7 +532,7 @@ class UserDirectoryHandler(object):
)
if self.is_mine_id(other_user_id) and not is_appservice:
shared_is_private = yield self.store.get_if_users_share_a_room(
other_user_id, user_id,
other_user_id, user_id
)
if shared_is_private is True:
# We've already marked in the database they share a private room
@ -528,13 +547,11 @@ class UserDirectoryHandler(object):
to_insert.add((other_user_id, user_id))
if to_insert:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
if to_update:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
room_id, not is_public, to_update
)
@defer.inlineCallbacks
@ -553,15 +570,15 @@ class UserDirectoryHandler(object):
row = yield self.store.get_user_in_public_room(user_id)
update_user_in_public = row and row["room_id"] == room_id
if (update_user_in_public or update_user_dir):
if update_user_in_public or update_user_dir:
# XXX: Make this faster?
rooms = yield self.store.get_rooms_for_user(user_id)
for j_room_id in rooms:
if (not update_user_in_public and not update_user_dir):
if not update_user_in_public and not update_user_dir:
break
is_in_room = yield self.store.is_host_joined(
j_room_id, self.server_name,
j_room_id, self.server_name
)
if not is_in_room:
@ -589,19 +606,19 @@ class UserDirectoryHandler(object):
# Get a list of user tuples that were in the DB due to this room and
# users (this includes tuples where the other user matches `user_id`)
user_tuples = yield self.store.get_users_in_share_dir_with_room_id(
user_id, room_id,
user_id, room_id
)
for user_id, other_user_id in user_tuples:
# For each user tuple get a list of rooms that they still share,
# trying to find a private room, and update the entry in the DB
rooms = yield self.store.get_rooms_in_common_for_users(user_id, other_user_id)
rooms = yield self.store.get_rooms_in_common_for_users(
user_id, other_user_id
)
# If they dont share a room anymore, remove the mapping
if not rooms:
yield self.store.remove_user_who_share_room(
user_id, other_user_id,
)
yield self.store.remove_user_who_share_room(user_id, other_user_id)
continue
found_public_share = None
@ -615,13 +632,13 @@ class UserDirectoryHandler(object):
else:
found_public_share = None
yield self.store.update_users_who_share_room(
room_id, not is_public, [(user_id, other_user_id)],
room_id, not is_public, [(user_id, other_user_id)]
)
break
if found_public_share:
yield self.store.update_users_who_share_room(
room_id, not is_public, [(user_id, other_user_id)],
room_id, not is_public, [(user_id, other_user_id)]
)
@defer.inlineCallbacks
@ -649,7 +666,7 @@ class UserDirectoryHandler(object):
if prev_name != new_name or prev_avatar != new_avatar:
yield self.store.update_profile_in_user_dir(
user_id, new_name, new_avatar, room_id,
user_id, new_name, new_avatar, room_id
)
@defer.inlineCallbacks

View file

@ -15,8 +15,10 @@
# limitations under the License.
import re
from twisted.internet import task
from twisted.internet.defer import CancelledError
from twisted.python import failure
from twisted.web.client import FileBodyProducer
from synapse.api.errors import SynapseError
@ -47,3 +49,16 @@ def redact_uri(uri):
r'\1<redacted>\3',
uri
)
class QuieterFileBodyProducer(FileBodyProducer):
"""Wrapper for FileBodyProducer that avoids CRITICAL errors when the connection drops.
Workaround for https://github.com/matrix-org/synapse/issues/4003 /
https://twistedmatrix.com/trac/ticket/6528
"""
def stopProducing(self):
try:
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
pass

View file

@ -15,34 +15,36 @@
# limitations under the License.
import logging
from io import BytesIO
from six import text_type
from six.moves import urllib
import treq
from canonicaljson import encode_canonical_json, json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, protocol, reactor, ssl
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web._newclient import ResponseDone
from twisted.web.client import (
Agent,
BrowserLikeRedirectAgent,
ContentDecoderAgent,
GzipDecoder,
HTTPConnectionPool,
PartialDownloadError,
readBody,
from twisted.internet import defer, protocol, ssl
from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IResolutionReceiver,
)
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
from synapse.http.endpoint import SpiderEndpoint
from synapse.http import (
QuieterFileBodyProducer,
cancelled_to_request_timed_out_error,
redact_uri,
)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
@ -50,8 +52,125 @@ from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
incoming_responses_counter = Counter("synapse_http_client_responses", "",
["method", "code"])
incoming_responses_counter = Counter(
"synapse_http_client_responses", "", ["method", "code"]
)
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
"""
Args:
ip_address (netaddr.IPAddress)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
return True
return False
class IPBlacklistingResolver(object):
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
addresses, preventing DNS rebinding attacks on URL preview.
"""
def __init__(self, reactor, ip_whitelist, ip_blacklist):
"""
Args:
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def resolveHostName(self, recv, hostname, portNumber=0):
r = recv()
d = defer.Deferred()
addresses = []
@provider(IResolutionReceiver)
class EndpointReceiver(object):
@staticmethod
def resolutionBegan(resolutionInProgress):
pass
@staticmethod
def addressResolved(address):
ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
"Dropped %s from DNS resolution to %s" % (ip_address, hostname)
)
raise SynapseError(403, "IP address blocked by IP blacklist entry")
addresses.append(address)
@staticmethod
def resolutionComplete():
d.callback(addresses)
self._reactor.nameResolver.resolveHostName(
EndpointReceiver, hostname, portNumber=portNumber
)
def _callback(addrs):
r.resolutionBegan(None)
for i in addrs:
r.addressResolved(i)
r.resolutionComplete()
d.addCallback(_callback)
return r
class BlacklistingAgentWrapper(Agent):
"""
An Agent wrapper which will prevent access to IP addresses being accessed
directly (without an IP address lookup).
"""
def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
"""
Args:
agent (twisted.web.client.Agent): The Agent to wrap.
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
h = urllib.parse.urlparse(uri.decode('ascii'))
try:
ip_address = IPAddress(h.hostname)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
"Blocking access to %s because of blacklist" % (ip_address,)
)
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
except Exception:
# Not an IP
pass
return self._agent.request(
method, uri, headers=headers, bodyProducer=bodyProducer
)
class SimpleHttpClient(object):
@ -59,14 +178,54 @@ class SimpleHttpClient(object):
A simple, no-frills HTTP client with methods that wrap up common ways of
using HTTP in Matrix
"""
def __init__(self, hs):
def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
"""
Args:
hs (synapse.server.HomeServer)
treq_args (dict): Extra keyword arguments to be given to treq.request.
ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
"""
self.hs = hs
pool = HTTPConnectionPool(reactor)
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
self._extra_treq_args = treq_args
self.user_agent = hs.version_string
self.clock = hs.get_clock()
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
self.user_agent = self.user_agent.encode('ascii')
if self._ip_blacklist:
real_reactor = hs.get_reactor()
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
nameResolver = IPBlacklistingResolver(
real_reactor, self._ip_whitelist, self._ip_blacklist
)
@implementer(IReactorPluggableNameResolver)
class Reactor(object):
def __getattr__(_self, attr):
if attr == "nameResolver":
return nameResolver
else:
return getattr(real_reactor, attr)
self.reactor = Reactor()
else:
self.reactor = hs.get_reactor()
# the pusher makes lots of concurrent SSL connections to sygnal, and
# tends to do so in batches, so we need to allow the pool to keep lots
# of idle connections around.
# tends to do so in batches, so we need to allow the pool to keep
# lots of idle connections around.
pool = HTTPConnectionPool(self.reactor)
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
@ -74,20 +233,35 @@ class SimpleHttpClient(object):
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
reactor,
self.reactor,
connectTimeout=15,
contextFactory=hs.get_http_client_context_factory(),
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
)
self.user_agent = hs.version_string
self.clock = hs.get_clock()
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
self.user_agent = self.user_agent.encode('ascii')
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
# which prevents direct access to IP addresses, that are not caught
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@defer.inlineCallbacks
def request(self, method, uri, data=b'', headers=None):
def request(self, method, uri, data=None, headers=None):
"""
Args:
method (str): HTTP method to use.
uri (str): URI to query.
data (bytes): Data to send in the request body, if applicable.
headers (t.w.http_headers.Headers): Request headers.
Raises:
SynapseError: If the IP is blacklisted.
"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.labels(method).inc()
@ -96,26 +270,39 @@ class SimpleHttpClient(object):
logger.info("Sending request %s %s", method, redact_uri(uri))
try:
body_producer = None
if data is not None:
body_producer = QuieterFileBodyProducer(BytesIO(data))
request_deferred = treq.request(
method, uri, agent=self.agent, data=data, headers=headers
method,
uri,
agent=self.agent,
data=body_producer,
headers=headers,
**self._extra_treq_args
)
request_deferred = timeout_deferred(
request_deferred, 60, self.hs.get_reactor(),
request_deferred,
60,
self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(request_deferred)
incoming_responses_counter.labels(method, response.code).inc()
logger.info(
"Received response to %s %s: %s",
method, redact_uri(uri), response.code
"Received response to %s %s: %s", method, redact_uri(uri), response.code
)
defer.returnValue(response)
except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
method, redact_uri(uri), type(e).__name__, e.args[0]
method,
redact_uri(uri),
type(e).__name__,
e.args[0],
)
raise
@ -140,8 +327,9 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.parse.urlencode(
encode_urlencode_args(args), True).encode("utf8")
query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
"utf8"
)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@ -151,15 +339,13 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
"POST",
uri,
headers=Headers(actual_headers),
data=query_bytes
"POST", uri, headers=Headers(actual_headers), data=query_bytes
)
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
body = yield make_deferred_yieldable(treq.json_content(response))
defer.returnValue(body)
defer.returnValue(json.loads(body))
else:
raise HttpResponseException(response.code, response.phrase, body)
@ -193,10 +379,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
"POST",
uri,
headers=Headers(actual_headers),
data=json_str
"POST", uri, headers=Headers(actual_headers), data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@ -264,10 +447,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
"PUT",
uri,
headers=Headers(actual_headers),
data=json_str
"PUT", uri, headers=Headers(actual_headers), data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@ -299,17 +479,11 @@ class SimpleHttpClient(object):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
actual_headers = {
b"User-Agent": [self.user_agent],
}
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
response = yield self.request(
"GET",
uri,
headers=Headers(actual_headers),
)
response = yield self.request("GET", uri, headers=Headers(actual_headers))
body = yield make_deferred_yieldable(readBody(response))
@ -334,22 +508,18 @@ class SimpleHttpClient(object):
headers, absolute URI of the response and HTTP response code.
"""
actual_headers = {
b"User-Agent": [self.user_agent],
}
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
response = yield self.request(
"GET",
url,
headers=Headers(actual_headers),
)
response = yield self.request("GET", url, headers=Headers(actual_headers))
resp_headers = dict(response.headers.getAllRawHeaders())
if (b'Content-Length' in resp_headers and
int(resp_headers[b'Content-Length']) > max_size):
if (
b'Content-Length' in resp_headers
and int(resp_headers[b'Content-Length'][0]) > max_size
):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@ -359,26 +529,20 @@ class SimpleHttpClient(object):
if response.code > 299:
logger.warn("Got %d when downloading %s" % (response.code, url))
raise SynapseError(
502,
"Got error %d" % (response.code,),
Codes.UNKNOWN,
)
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
# straight back in again
try:
length = yield make_deferred_yieldable(_readBodyToFile(
response, output_stream, max_size,
))
length = yield make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
)
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
502,
("Failed to download remote body: %s" % e),
Codes.UNKNOWN,
502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
)
defer.returnValue(
@ -387,13 +551,14 @@ class SimpleHttpClient(object):
resp_headers,
response.request.absoluteURI.decode('ascii'),
response.code,
),
)
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
self.stream = stream
@ -405,11 +570,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
))
self.deferred.errback(
SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
)
)
self.deferred = defer.Deferred()
self.transport.loseConnection()
@ -427,6 +594,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
@ -449,10 +617,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
"POST",
url,
data=query_bytes,
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
})
headers=Headers(
{
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
}
),
)
try:
@ -463,57 +633,6 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response)
class SpiderEndpointFactory(object):
def __init__(self, hs):
self.blacklist = hs.config.url_preview_ip_range_blacklist
self.whitelist = hs.config.url_preview_ip_range_whitelist
self.policyForHTTPS = hs.get_http_client_context_factory()
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == b"http":
endpoint_factory = HostnameEndpoint
elif uri.scheme == b"https":
tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
def endpoint_factory(reactor, host, port, **kw):
return wrapClientTLS(
tlsCreator,
HostnameEndpoint(reactor, host, port, **kw))
else:
logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
return None
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
)
class SpiderHttpClient(SimpleHttpClient):
"""
Separate HTTP client for spidering arbitrary URLs.
Special in that it follows retries and has a UA that looks
like a browser.
used by the preview_url endpoint in the content repo.
"""
def __init__(self, hs):
SimpleHttpClient.__init__(self, hs)
# clobber the base class's agent and UA:
self.agent = ContentDecoderAgent(
BrowserLikeRedirectAgent(
Agent.usingEndpointFactory(
reactor,
SpiderEndpointFactory(hs)
)
), [(b'gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
# Chrome Safari" % hs.version_string)
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}

View file

@ -12,30 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import random
import re
import time
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
# "host" is the hostname acquired from the SRV record. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
def parse_server_name(server_name):
"""Split a server name into host/port parts.
@ -100,299 +81,3 @@ def parse_and_validate_server_name(server_name):
))
return host, port
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
Args:
reactor: Twisted reactor.
destination (unicode): The name of the server to connect to.
tls_client_options_factory
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
Factory which generates TLS options for client connections.
timeout (int): connection timeout in seconds
"""
domain, port = parse_server_name(destination)
endpoint_kw_args = {}
if timeout is not None:
endpoint_kw_args.update(timeout=timeout)
if tls_client_options_factory is None:
transport_endpoint = HostnameEndpoint
default_port = 8008
else:
# the SNI string should be the same as the Host header, minus the port.
# as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
# the Host header and SNI should therefore be the server_name of the remote
# server.
tls_options = tls_client_options_factory.get_options(domain)
def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS(
tls_options,
HostnameEndpoint(reactor, host, port, timeout=timeout),
)
default_port = 8448
if port is None:
return _WrappingEndpointFac(SRVClientEndpoint(
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
), reactor)
else:
return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args
), reactor)
class _WrappingEndpointFac(object):
def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac
self.reactor = reactor
@defer.inlineCallbacks
def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory)
conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn)
class _WrappedConnection(object):
"""Wraps a connection and calls abort on it if it hasn't seen any action
for 2.5-3 minutes.
"""
__slots__ = ["conn", "last_request"]
def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time())
self._reactor = reactor
def __getattr__(self, name):
return getattr(self.conn, name)
def __setattr__(self, name, value):
setattr(self.conn, name, value)
def _time_things_out_maybe(self):
# We use a slightly shorter timeout here just in case the callLater is
# triggered early. Paranoia ftw.
# TODO: Cancel the previous callLater rather than comparing time.time()?
if time.time() - self.last_request >= 2.5 * 60:
self.abort()
# Abort the underlying TLS connection. The abort() method calls
# loseConnection() on the TLS connection which tries to
# shutdown the connection cleanly. We call abortConnection()
# since that will promptly close the TLS connection.
#
# In Twisted >18.4; the TLS connection will be None if it has closed
# which will make abortConnection() throw. Check that the TLS connection
# is not None before trying to close it.
if self.transport.getHandle() is not None:
self.transport.abortConnection()
def request(self, request):
self.last_request = time.time()
# Time this connection out if we haven't send a request in the last
# N minutes
# TODO: Cancel the previous callLater?
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request)
def update_request_time(res):
self.last_request = time.time()
# TODO: Cancel the previous callLater?
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res
d.addCallback(update_request_time)
return d
class SpiderEndpoint(object):
"""An endpoint which refuses to connect to blacklisted IP addresses
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, host, port, blacklist, whitelist,
endpoint=HostnameEndpoint, endpoint_kw_args={}):
self.reactor = reactor
self.host = host
self.port = port
self.blacklist = blacklist
self.whitelist = whitelist
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
@defer.inlineCallbacks
def connect(self, protocolFactory):
address = yield self.reactor.resolve(self.host)
from netaddr import IPAddress
ip_address = IPAddress(address)
if ip_address in self.blacklist:
if self.whitelist is None or ip_address not in self.whitelist:
raise ConnectError(
"Refusing to spider blacklisted IP address %s" % address
)
logger.info("Connecting to %s:%s", address, self.port)
endpoint = self.endpoint(
self.reactor, address, self.port, **self.endpoint_kw_args
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect
picking the next server.
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=HostnameEndpoint,
endpoint_kw_args={}):
self.reactor = reactor
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None:
self.default_server = _Server(
host=domain,
port=default_port,
priority=0,
weight=0,
expires=0,
)
else:
self.default_server = None
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
self.servers = None
self.used_servers = None
@defer.inlineCallbacks
def fetch_servers(self):
self.used_servers = []
self.servers = yield resolve_service(self.service_name)
def pick_server(self):
if not self.servers:
if self.used_servers:
self.servers = self.used_servers
self.used_servers = []
self.servers.sort()
elif self.default_server:
return self.default_server
else:
raise ConnectError(
"No server available for %s" % self.service_name
)
# look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
for index, server in enumerate(self.servers)
if server.priority == min_priority
)
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
# XXX: this looks totally dubious:
#
# (a) we never reuse a server until we have been through
# all of the servers at the same priority, so if the
# weights are A: 100, B:1, we always do ABABAB instead of
# AAAA...AAAB (approximately).
#
# (b) After using all the servers at the lowest priority,
# we move onto the next priority. We should only use the
# second priority if servers at the top priority are
# unreachable.
#
del self.servers[index]
self.used_servers.append(server)
return server
@defer.inlineCallbacks
def connect(self, protocolFactory):
if self.servers is None:
yield self.fetch_servers()
server = self.pick_server()
logger.info("Connecting to %s:%s", server.host, server.port)
endpoint = self.endpoint(
self.reactor, server.host, server.port, **self.endpoint_kw_args
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = []
try:
try:
answers, _, _ = yield dns_client.lookupService(service_name)
except DNSNameError:
defer.returnValue([])
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
servers.append(_Server(
host=str(payload.target),
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + answer.ttl,
))
servers.sort()
cache[service_name] = list(servers)
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
servers = list(cache_entry)
else:
raise e
defer.returnValue(servers)

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View file

@ -0,0 +1,452 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import random
import time
import attr
from netaddr import IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util import Clock
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
# period to cache .well-known results for by default
WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
# jitter to add to the .well-known default cache ttl
WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
# period to cache failure to fetch .well-known for
WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
# cap for .well-known cache period
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
logger = logging.getLogger(__name__)
well_known_cache = TTLCache('well-known')
@implementer(IAgent)
class MatrixFederationAgent(object):
"""An Agent-like thing which provides a `request` method which will look up a matrix
server and send an HTTP request to it.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
reactor (IReactor): twisted reactor to use for underlying requests
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
_well_known_tls_policy (IPolicyForHTTPS|None):
TLS policy to use for fetching .well-known files. None to use a default
(browser-like) implementation.
srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
"""
def __init__(
self, reactor, tls_client_options_factory,
_well_known_tls_policy=None,
_srv_resolver=None,
_well_known_cache=well_known_cache,
):
self._reactor = reactor
self._clock = Clock(reactor)
self._tls_client_options_factory = tls_client_options_factory
if _srv_resolver is None:
_srv_resolver = SrvResolver()
self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
agent_args = {}
if _well_known_tls_policy is not None:
# the param is called 'contextFactory', but actually passing a
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
agent_args['contextFactory'] = _well_known_tls_policy
_well_known_agent = RedirectAgent(
Agent(self._reactor, pool=self._pool, **agent_args),
)
self._well_known_agent = _well_known_agent
# our cache of .well-known lookup results, mapping from server name
# to delegated name. The values can be:
# `bytes`: a valid server-name
# `None`: there is no (valid) .well-known here
self._well_known_cache = _well_known_cache
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
uri (bytes): Absolute URI to be retrieved
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
response status code). Fails if there is any problem which prevents that
response from being received (including problems that prevent the request
from being sent).
"""
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
res = yield self._route_matrix_uri(parsed_uri)
# set up the TLS connection params
#
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
if self._tls_client_options_factory is None:
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(
res.tls_server_name.decode("ascii")
)
# make sure that the Host header is set correctly
if headers is None:
headers = Headers()
else:
headers = headers.copy()
if not headers.hasHeader(b'host'):
headers.addRawHeader(b'host', res.host_header)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
ep = LoggingHostnameEndpoint(
self._reactor, res.target_host, res.target_port,
)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
return ep
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield make_deferred_yieldable(
agent.request(method, uri, headers, bodyProducer)
)
defer.returnValue(res)
@defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
"""Helper for `request`: determine the routing for a Matrix URI
Args:
parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
if there is no explicit port given.
lookup_well_known (bool): True if we should look up the .well-known file if
there is no SRV record.
Returns:
Deferred[_RoutingResult]
"""
# check for an IP literal
try:
ip_address = IPAddress(parsed_uri.host.decode("ascii"))
except Exception:
# not an IP address
ip_address = None
if ip_address:
port = parsed_uri.port
if port == -1:
port = 8448
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=port,
))
if parsed_uri.port != -1:
# there is an explicit port
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=parsed_uri.port,
))
if lookup_well_known:
# try a .well-known lookup
well_known_server = yield self._get_well_known(parsed_uri.host)
if well_known_server:
# if we found a .well-known, start again, but don't do another
# .well-known lookup.
# parse the server name in the .well-known response into host/port.
# (This code is lifted from twisted.web.client.URI.fromBytes).
if b':' in well_known_server:
well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
try:
well_known_port = int(well_known_port)
except ValueError:
# the part after the colon could not be parsed as an int
# - we assume it is an IPv6 literal with no port (the closing
# ']' stops it being parsed as an int)
well_known_host, well_known_port = well_known_server, -1
else:
well_known_host, well_known_port = well_known_server, -1
new_uri = URI(
scheme=parsed_uri.scheme,
netloc=well_known_server,
host=well_known_host,
port=well_known_port,
path=parsed_uri.path,
params=parsed_uri.params,
query=parsed_uri.query,
fragment=parsed_uri.fragment,
)
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
defer.returnValue(res)
# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list:
target_host = parsed_uri.host
port = 8448
logger.debug(
"No SRV record for %s, using %s:%i",
parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
)
else:
target_host, port = pick_server_from_list(server_list)
logger.debug(
"Picked %s:%i from SRV records for %s",
target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
)
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=target_host,
target_port=port,
))
@defer.inlineCallbacks
def _get_well_known(self, server_name):
"""Attempt to fetch and parse a .well-known file for the given server
Args:
server_name (bytes): name of the server, from the requested url
Returns:
Deferred[bytes|None]: either the new server name, from the .well-known, or
None if there was no .well-known file.
"""
try:
result = self._well_known_cache[server_name]
except KeyError:
# TODO: should we linearise so that we don't end up doing two .well-known
# requests for the same server in parallel?
with Measure(self._clock, "get_well_known"):
result, cache_period = yield self._do_get_well_known(server_name)
if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)
defer.returnValue(result)
@defer.inlineCallbacks
def _do_get_well_known(self, server_name):
"""Actually fetch and parse a .well-known, without checking the cache
Args:
server_name (bytes): name of the server, from the requested url
Returns:
Deferred[Tuple[bytes|None|object],int]:
result, cache period, where result is one of:
- the new server name from the .well-known (as a `bytes`)
- None if there was no .well-known file.
- INVALID_WELL_KNOWN if the .well-known was invalid
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name, )
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
self._well_known_agent.request(b"GET", uri),
)
body = yield make_deferred_yieldable(readBody(response))
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code, ))
parsed_body = json.loads(body.decode('utf-8'))
logger.info("Response from .well-known: %s", parsed_body)
if not isinstance(parsed_body, dict):
raise Exception("not a dict")
if "m.server" not in parsed_body:
raise Exception("Missing key 'm.server'")
except Exception as e:
logger.info("Error fetching %s: %s", uri_str, e)
# add some randomness to the TTL to avoid a stampeding herd every hour
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
defer.returnValue((None, cache_period))
result = parsed_body["m.server"].encode("ascii")
cache_period = _cache_period_from_headers(
response.headers,
time_now=self._reactor.seconds,
)
if cache_period is None:
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
# add some randomness to the TTL to avoid a stampeding herd every 24 hours
# after startup
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
else:
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
defer.returnValue((result, cache_period))
@implementer(IStreamClientEndpoint)
class LoggingHostnameEndpoint(object):
"""A wrapper for HostnameEndpint which logs when it connects"""
def __init__(self, reactor, host, port, *args, **kwargs):
self.host = host
self.port = port
self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
def connect(self, protocol_factory):
logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
return self.ep.connect(protocol_factory)
def _cache_period_from_headers(headers, time_now=time.time):
cache_controls = _parse_cache_control(headers)
if b'no-store' in cache_controls:
return 0
if b'max-age' in cache_controls:
try:
max_age = int(cache_controls[b'max-age'])
return max_age
except ValueError:
pass
expires = headers.getRawHeaders(b'expires')
if expires is not None:
try:
expires_date = stringToDatetime(expires[-1])
return expires_date - time_now()
except ValueError:
# RFC7234 says 'A cache recipient MUST interpret invalid date formats,
# especially the value "0", as representing a time in the past (i.e.,
# "already expired").
return 0
return None
def _parse_cache_control(headers):
cache_controls = {}
for hdr in headers.getRawHeaders(b'cache-control', []):
for directive in hdr.split(b','):
splits = [x.strip() for x in directive.split(b'=', 1)]
k = splits[0].lower()
v = splits[1] if len(splits) > 1 else None
cache_controls[k] = v
return cache_controls
@attr.s
class _RoutingResult(object):
"""The result returned by `_route_matrix_uri`.
Contains the parameters needed to direct a federation connection to a particular
server.
Where a SRV record points to several servers, this object contains a single server
chosen from the list.
"""
host_header = attr.ib()
"""
The value we should assign to the Host header (host:port from the matrix
URI, or .well-known).
:type: bytes
"""
tls_server_name = attr.ib()
"""
The server name we should set in the SNI (typically host, without port, from the
matrix URI or .well-known)
:type: bytes
"""
target_host = attr.ib()
"""
The hostname (or IP literal) we should route the TCP connection to (the target of the
SRV record, or the hostname from the URL/.well-known)
:type: bytes
"""
target_port = attr.ib()
"""
The port we should route the TCP connection to (the target of the SRV record, or
the port from the URL/.well-known, or 8448)
:type: int
"""

View file

@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import time
import attr
from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
@attr.s
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
Attributes:
host (bytes): target hostname
port (int):
priority (int):
weight (int):
expires (int): when the cache should expire this record - in *seconds* since
the epoch
"""
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
weight = attr.ib(default=0)
expires = attr.ib(default=0)
def pick_server_from_list(server_list):
"""Randomly choose a server from the server list
Args:
server_list (list[Server]): list of candidate servers
Returns:
Tuple[bytes, int]: (host, port) pair for the chosen server
"""
if not server_list:
raise RuntimeError("pick_server_from_list called with empty list")
# TODO: currently we only use the lowest-priority servers. We should maintain a
# cache of servers known to be "down" and filter them out
min_priority = min(s.priority for s in server_list)
eligible_servers = list(s for s in server_list if s.priority == min_priority)
total_weight = sum(s.weight for s in eligible_servers)
target_weight = random.randint(0, total_weight)
for s in eligible_servers:
target_weight -= s.weight
if target_weight <= 0:
return s.host, s.port
# this should be impossible.
raise RuntimeError(
"pick_server_from_list got to end of eligible server list.",
)
class SrvResolver(object):
"""Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
"""
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
@defer.inlineCallbacks
def resolve_service(self, service_name):
"""Look up a SRV record
Args:
service_name (bytes): record to look up
Returns:
Deferred[list[Server]]:
a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
if not isinstance(service_name, bytes):
raise TypeError("%r is not a byte string" % (service_name,))
cache_entry = self._cache.get(service_name, None)
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
self._dns_client.lookupService(service_name),
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=now + answer.ttl,
))
self._cache[service_name] = list(servers)
defer.returnValue(servers)

View file

@ -19,7 +19,7 @@ import random
import sys
from io import BytesIO
from six import PY3, string_types
from six import PY3, raise_from, string_types
from six.moves import urllib
import attr
@ -32,7 +32,6 @@ from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
from twisted.web.http_headers import Headers
import synapse.metrics
@ -41,9 +40,11 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.http import QuieterFileBodyProducer
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
@ -65,20 +66,6 @@ else:
MAXINT = sys.maxint
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.reactor = hs.get_reactor()
self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri):
destination = uri.netloc.decode('ascii')
return matrix_federation_endpoint(
self.reactor, destination, timeout=10,
tls_client_options_factory=self.tls_client_options_factory
)
_next_id = 1
@ -181,17 +168,15 @@ class MatrixFederationHttpClient(object):
requests.
"""
def __init__(self, hs):
def __init__(self, hs, tls_client_options_factory):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
reactor = hs.get_reactor()
pool = HTTPConnectionPool(reactor)
pool.retryAutomatically = False
pool.maxPersistentPerHost = 5
pool.cachedConnectionTimeout = 2 * 60
self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
self.agent = MatrixFederationAgent(
hs.get_reactor(),
tls_client_options_factory,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
@ -228,19 +213,18 @@ class MatrixFederationHttpClient(object):
backoff_on_404 (bool): Back off if we get a 404
Returns:
Deferred: resolves with the http response object on success.
Deferred[twisted.web.client.Response]: resolves with the HTTP
response object on success.
Fails with ``HttpResponseException``: if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
(May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.)
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
if timeout:
_sec_timeout = timeout / 1000
@ -271,7 +255,6 @@ class MatrixFederationHttpClient(object):
headers_dict = {
b"User-Agent": [self.version_string_bytes],
b"Host": [destination_bytes],
}
with limiter:
@ -298,60 +281,111 @@ class MatrixFederationHttpClient(object):
json = request.get_json()
if json:
headers_dict[b"Content-Type"] = [b"application/json"]
self.sign_request(
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
headers_dict, json,
json,
)
data = encode_canonical_json(json)
producer = FileBodyProducer(
producer = QuieterFileBodyProducer(
BytesIO(data),
cooperator=self._cooperator,
)
else:
producer = None
self.sign_request(
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
headers_dict,
)
headers_dict[b"Authorization"] = auth_headers
logger.info(
"{%s} [%s] Sending request: %s %s",
"{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.destination, request.method,
url_str,
url_str, _sec_timeout,
)
# we don't want all the fancy cookie and redirect handling that
# treq.request gives: just use the raw Agent.
request_deferred = self.agent.request(
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
try:
with Measure(self.clock, "outbound_request"):
# we don't want all the fancy cookie and redirect handling
# that treq.request gives: just use the raw Agent.
request_deferred = self.agent.request(
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
)
response = yield request_deferred
except DNSLookupError as e:
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
except Exception as e:
logger.info("Failed to send request: %s", e)
raise_from(RequestSendFailed(e, can_retry=True), e)
logger.info(
"{%s} [%s] Got response headers: %d %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode('ascii', errors='replace'),
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
)
with Measure(self.clock, "outbound_request"):
response = yield make_deferred_yieldable(
request_deferred,
if 200 <= response.code < 300:
pass
else:
# :'(
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
d,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
)
try:
body = yield make_deferred_yieldable(d)
except Exception as e:
# Eh, we're already going to raise an exception so lets
# ignore if this fails.
logger.warn(
"{%s} [%s] Failed to get error response: %s %s: %s",
request.txn_id,
request.destination,
request.method,
url_str,
_flatten_response_never_received(e),
)
body = None
e = HttpResponseException(
response.code, response.phrase, body
)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
raise_from(RequestSendFailed(e, can_retry=True), e)
else:
raise e
break
except Exception as e:
except RequestSendFailed as e:
logger.warn(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
request.method,
url_str,
_flatten_response_never_received(e),
_flatten_response_never_received(e.inner_exception),
)
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
if not e.can_retry:
raise
if retries_left and not timeout:
@ -376,50 +410,36 @@ class MatrixFederationHttpClient(object):
else:
raise
logger.info(
"{%s} [%s] Got response headers: %d %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode('ascii', errors='replace'),
)
if 200 <= response.code < 300:
pass
else:
# :'(
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
d,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
)
body = yield make_deferred_yieldable(d)
raise HttpResponseException(
response.code, response.phrase, body
)
except Exception as e:
logger.warn(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
request.method,
url_str,
_flatten_response_never_received(e),
)
raise
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None, destination_is=None):
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None,
):
"""
Signs a request by adding an Authorization header to headers_dict
Builds the Authorization headers for a federation request
Args:
destination (bytes|None): The desination home server of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
url_bytes (bytes): The URI path of the request
headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to
append to
content (object): The body of the request
destination_is (bytes): As 'destination', but if the destination is an
identity server
Returns:
None
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
request = {
"method": method,
@ -446,8 +466,7 @@ class MatrixFederationHttpClient(object):
self.server_name, key, sig,
)).encode('ascii')
)
headers_dict[b"Authorization"] = auth_headers
return auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, args={}, data={},
@ -477,17 +496,18 @@ class MatrixFederationHttpClient(object):
requests)
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
@ -531,17 +551,18 @@ class MatrixFederationHttpClient(object):
try the request anyway.
args (dict): query params
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
@ -586,17 +607,18 @@ class MatrixFederationHttpClient(object):
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
logger.debug("get_json args: %s", args)
@ -637,17 +659,18 @@ class MatrixFederationHttpClient(object):
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="DELETE",
@ -680,18 +703,20 @@ class MatrixFederationHttpClient(object):
args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
Deferred: resolves with an (int,dict) tuple of the file length and
a dict of the response headers.
Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Fails with ``HttpResponseException`` if we get an HTTP response code
>= 300
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="GET",
@ -784,21 +809,21 @@ def check_content_type_is_json(headers):
headers (twisted.web.http_headers.Headers): headers to check
Raises:
RuntimeError if the
RequestSendFailed: if the Content-Type header is missing or isn't JSON
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
raise RuntimeError(
raise RequestSendFailed(RuntimeError(
"No Content-Type header"
)
), can_retry=False)
c_type = c_type[0].decode('ascii') # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RuntimeError(
raise RequestSendFailed(RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
)
), can_retry=False)
def encode_query_args(args):

View file

@ -106,10 +106,10 @@ def wrap_json_request_handler(h):
# trace.
f = failure.Failure()
logger.error(
"Failed handle request via %r: %r: %s",
h,
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
f.getTraceback().rstrip(),
exc_info=(f.type, f.value, f.getTracebackObject()),
)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
@ -468,13 +468,13 @@ def set_cors_headers(request):
Args:
request (twisted.web.http.Request): The http request to add CORs to.
"""
request.setHeader("Access-Control-Allow-Origin", "*")
request.setHeader(b"Access-Control-Allow-Origin", b"*")
request.setHeader(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
)
request.setHeader(
"Access-Control-Allow-Headers",
"Origin, X-Requested-With, Content-Type, Accept, Authorization"
b"Access-Control-Allow-Headers",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
)

View file

@ -121,16 +121,15 @@ def parse_string(request, name, default=None, required=False,
Args:
request: the twisted HTTP request.
name (bytes/unicode): the name of the query parameter.
default (bytes/unicode|None): value to use if the parameter is absent,
name (bytes|unicode): the name of the query parameter.
default (bytes|unicode|None): value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[bytes/unicode]): List of allowed values for the
allowed_values (list[bytes|unicode]): List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the name to, and decode the string
content with.
encoding (str|None): The encoding to decode the string content with.
Returns:
bytes/unicode|None: A string value or the default. Unicode if encoding

View file

@ -274,8 +274,6 @@ pending_calls_metric = Histogram(
# Federation Metrics
#
sent_edus_counter = Counter("synapse_federation_client_sent_edus", "")
sent_transactions_counter = Counter("synapse_federation_client_sent_transactions", "")
events_processed_counter = Counter("synapse_federation_client_events_processed", "")

View file

@ -79,7 +79,7 @@ class ModuleApi(object):
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
reg = self.hs.get_registration_handler()
return reg.register(localpart=localpart)
@defer.inlineCallbacks

View file

@ -84,7 +84,7 @@ def _rule_to_template(rule):
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule

Some files were not shown because too many files have changed in this diff Show more