merge proper fix to bug 2969

This commit is contained in:
Matthew Hodgson 2018-03-13 22:11:58 +00:00
commit 12350e3f9a
46 changed files with 578 additions and 338 deletions

View File

@ -32,3 +32,30 @@ specified by including an event_id in the URI, or by setting a
id is given, that event (and others at the same graph depth) will be retained. id is given, that event (and others at the same graph depth) will be retained.
If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch, If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch,
in milliseconds. in milliseconds.
The API starts the purge running, and returns immediately with a JSON body with
a purge id:
.. code:: json
{
"purge_id": "<opaque id>"
}
Purge status query
------------------
It is possible to poll for updates on recent purges with a second API;
``GET /_matrix/client/r0/admin/purge_history_status/<purge_id>``
(again, with a suitable ``access_token``). This API returns a JSON body like
the following:
.. code:: json
{
"status": "active"
}
The status will be one of ``active``, ``complete``, or ``failed``.

View File

@ -279,9 +279,9 @@ Obviously that option means that the operations done in
that might be fixed by setting a different logcontext via a ``with that might be fixed by setting a different logcontext via a ``with
LoggingContext(...)`` in ``background_operation``). LoggingContext(...)`` in ``background_operation``).
The second option is to use ``logcontext.preserve_fn``, which wraps a function The second option is to use ``logcontext.run_in_background``, which wraps a
so that it doesn't reset the logcontext even when it returns an incomplete function so that it doesn't reset the logcontext even when it returns an
deferred, and adds a callback to the returned deferred to reset the incomplete deferred, and adds a callback to the returned deferred to reset the
logcontext. In other words, it turns a function that follows the Synapse rules logcontext. In other words, it turns a function that follows the Synapse rules
about logcontexts and Deferreds into one which behaves more like an external about logcontexts and Deferreds into one which behaves more like an external
function — the opposite operation to that described in the previous section. function — the opposite operation to that described in the previous section.
@ -293,7 +293,7 @@ It can be used like this:
def do_request_handling(): def do_request_handling():
yield foreground_operation() yield foreground_operation()
logcontext.preserve_fn(background_operation)() logcontext.run_in_background(background_operation)
# this will now be logged against the request context # this will now be logged against the request context
logger.debug("Request handling complete") logger.debug("Request handling complete")

View File

@ -156,7 +156,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -161,7 +161,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -144,7 +144,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -211,7 +211,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -348,7 +348,7 @@ def setup(config_options):
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_federation_client().start_get_pdu_cache()
register_memory_metrics(hs) register_memory_metrics(hs)

View File

@ -158,7 +158,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View File

@ -15,11 +15,3 @@
""" This package includes all the federation specific logic. """ This package includes all the federation specific logic.
""" """
from .replication import ReplicationLayer
def initialize_http_replication(hs):
transport = hs.get_federation_transport_client()
return ReplicationLayer(hs, transport)

View File

@ -27,7 +27,13 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastore()
self._clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,

View File

@ -58,6 +58,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000, self._clear_tried_cache, 60 * 1000,
) )
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""

View File

@ -17,12 +17,14 @@ import logging
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import ( from synapse.federation.federation_base import (
FederationBase, FederationBase,
event_from_pdu_json, event_from_pdu_json,
) )
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
import synapse.metrics import synapse.metrics
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -52,50 +54,19 @@ class FederationServer(FederationBase):
super(FederationServer, self).__init__(hs) super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
self._server_linearizer = async.Linearizer("fed_server") self._server_linearizer = async.Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler") self._transaction_linearizer = async.Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store)
self.registry = hs.get_federation_registry()
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, origin, room_id, versions, limit): def on_backfill_request(self, origin, room_id, versions, limit):
@ -229,16 +200,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def received_edu(self, origin, edu_type, content): def received_edu(self, origin, edu_type, content):
received_edus_counter.inc() received_edus_counter.inc()
yield self.registry.on_edu(edu_type, origin, content)
if edu_type in self.edu_handlers:
try:
yield self.edu_handlers[edu_type](origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -328,14 +290,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_request(self, query_type, args): def on_query_request(self, query_type, args):
received_queries_counter.inc(query_type) received_queries_counter.inc(query_type)
resp = yield self.registry.on_query(query_type, args)
if query_type in self.query_handlers: defer.returnValue((200, resp))
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id): def on_make_join_request(self, room_id, user_id):
@ -607,3 +563,66 @@ class FederationServer(FederationBase):
origin, room_id, event_dict origin, room_id, event_dict
) )
defer.returnValue(ret) defer.returnValue(ret)
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
def register_edu_handler(self, edu_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
Args:
edu_type (str): The type of the incoming EDU to register handler for
handler (Callable[[str, dict]]): A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and
the EDU contents.
"""
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (Callable[[dict], Deferred[dict]]): Invoked to handle
incoming queries of this type. The return will be yielded
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
logger.warn("No handler registered for EDU type %s", edu_type)
try:
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
logger.warn("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)

View File

@ -1,73 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This layer is responsible for replicating with remote home servers using
a given transport.
"""
from .federation_client import FederationClient
from .federation_server import FederationServer
from .persistence import TransactionActions
import logging
logger = logging.getLogger(__name__)
class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
The layer communicates with the rest of the server via a registered
ReplicationHandler.
In more detail, the layer:
* Receives incoming data and processes it into transactions and pdus.
* Fetches any PDUs it thinks it might have missed.
* Keeps the current state for contexts up to date by applying the
suitable conflict resolution.
* Sends outgoing pdus wrapped in transactions.
* Fills out the references to previous pdus/transactions appropriately
for outgoing data.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
self.federation_client = self
self.store = hs.get_datastore()
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
self.hs = hs
super(ReplicationLayer, self).__init__(hs)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name

View File

@ -1190,7 +1190,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
def register_servlets(hs, resource, authenticator, ratelimiter): def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES: for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass( servletclass(
handler=hs.get_replication_layer(), handler=hs.get_federation_server(),
authenticator=authenticator, authenticator=authenticator,
ratelimiter=ratelimiter, ratelimiter=ratelimiter,
server_name=hs.hostname, server_name=hs.hostname,

View File

@ -37,14 +37,15 @@ class DeviceHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._edu_updater = DeviceListEduUpdater(hs, self) self._edu_updater = DeviceListEduUpdater(hs, self)
self.federation.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update, "m.device_list_update", self._edu_updater.incoming_device_list_update,
) )
self.federation.register_query_handler( federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices, "user_devices", self.on_federation_query_user_devices,
) )
@ -430,7 +431,7 @@ class DeviceListEduUpdater(object):
def __init__(self, hs, device_handler): def __init__(self, hs, device_handler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.device_handler = device_handler self.device_handler = device_handler

View File

@ -37,7 +37,7 @@ class DeviceMessageHandler(object):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu "m.direct_to_device", self.on_direct_to_device_edu
) )

View File

@ -36,8 +36,8 @@ class DirectoryHandler(BaseHandler):
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query "directory", self.on_directory_query
) )

View File

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler(object): class E2eKeysHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -40,7 +40,7 @@ class E2eKeysHandler(object):
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
# "query handler" interface. # "query handler" interface.
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )

View File

@ -68,7 +68,7 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -78,8 +78,6 @@ class FederationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
self.room_queues = {} self.room_queues = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu") self._room_pdu_linearizer = Linearizer("fed_room_pdu")

View File

@ -13,7 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
@ -24,9 +25,10 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, UserID, RoomAlias, RoomStreamToken,
) )
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from synapse.replication.http.send_event import send_event_to_master from synapse.replication.http.send_event import send_event_to_master
@ -41,6 +43,36 @@ import ujson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PurgeStatus(object):
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
Attributes:
status (int): Tracks whether this request has completed. One of
STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
STATUS_COMPLETE = 1
STATUS_FAILED = 2
STATUS_TEXT = {
STATUS_ACTIVE: "active",
STATUS_COMPLETE: "complete",
STATUS_FAILED: "failed",
}
def __init__(self):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
return {
"status": PurgeStatus.STATUS_TEXT[self.status]
}
class MessageHandler(BaseHandler): class MessageHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -50,14 +82,87 @@ class MessageHandler(BaseHandler):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set()
# map from purge id to PurgeStatus
self._purges_by_id = {}
def start_purge_history(self, room_id, topological_ordering,
delete_local_events=False):
"""Start off a history purge on a room.
Args:
room_id (str): The room to purge from
topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
Returns:
str: unique ID for this purge transaction.
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
400,
"History purge already in progress for %s" % (room_id, ),
)
purge_id = random_string(16)
# we log the purge_id here so that it can be tied back to the
# request id in the log lines.
logger.info("[purge] starting purge_id %s", purge_id)
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
self._purge_history,
purge_id, room_id, topological_ordering, delete_local_events,
)
return purge_id
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, topological_ordering, def _purge_history(self, purge_id, room_id, topological_ordering,
delete_local_events=False): delete_local_events):
"""Carry out a history purge on a room.
Args:
purge_id (str): The id for this purge
room_id (str): The room to purge from
topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
Returns:
Deferred
"""
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)): with (yield self.pagination_lock.write(room_id)):
yield self.store.purge_history( yield self.store.purge_history(
room_id, topological_ordering, delete_local_events, room_id, topological_ordering, delete_local_events,
) )
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
self._purges_in_progress_by_room.discard(room_id)
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
reactor.callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
"""Get the current status of an active purge
Args:
purge_id (str): purge_id returned by start_purge_history
Returns:
PurgeStatus|None
"""
return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
@ -562,7 +667,7 @@ class EventCreationHandler(object):
event (FrozenEvent) event (FrozenEvent)
context (EventContext) context (EventContext)
ratelimit (bool) ratelimit (bool)
extra_users (list(str)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
try: try:

View File

@ -93,29 +93,30 @@ class PresenceHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.wheel_timer = WheelTimer() self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.replication.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.presence", self.incoming_presence "m.presence", self.incoming_presence
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_invite", "m.presence_invite",
lambda origin, content: self.invite_presence( lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_accept", "m.presence_accept",
lambda origin, content: self.accept_presence( lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_deny", "m.presence_deny",
lambda origin, content: self.deny_presence( lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),

View File

@ -31,8 +31,8 @@ class ProfileHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) super(ProfileHandler, self).__init__(hs)
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query "profile", self.on_profile_query
) )

View File

@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt "m.receipt", self._received_remote_receipt
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()

View File

@ -446,16 +446,34 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler() return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
one doesn't already exist.
Args:
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
"""
access_token = yield self.store.get_3pid_guest_access_token(medium, address) access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token: if access_token:
defer.returnValue(access_token) user_info = yield self.auth.get_user_by_access_token(
access_token
)
_, access_token = yield self.register( defer.returnValue((user_info["user"].to_string(), access_token))
user_id, access_token = yield self.register(
generate_token=True, generate_token=True,
make_guest=True make_guest=True
) )
access_token = yield self.store.save_or_get_3pid_guest_access_token( access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id medium, address, access_token, inviter_user_id
) )
defer.returnValue(access_token)
defer.returnValue((user_id, access_token))

View File

@ -409,7 +409,7 @@ class RoomListHandler(BaseHandler):
def _get_remote_list_cached(self, server_name, limit=None, since_token=None, def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None, include_all_networks=False, search_filter=None, include_all_networks=False,
third_party_instance_id=None,): third_party_instance_id=None,):
repl_layer = self.hs.get_replication_layer() repl_layer = self.hs.get_federation_client()
if search_filter: if search_filter:
# We can't cache when asking for search # We can't cache when asking for search
return repl_layer.get_public_rooms( return repl_layer.get_public_rooms(

View File

@ -55,7 +55,6 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler() self.event_creation_hander = hs.get_event_creation_handler()
self.replication_layer = hs.get_replication_layer()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
@ -138,7 +137,20 @@ class RoomMemberHandler(object):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content): def _remote_join(self, remote_room_hosts, room_id, user, content):
"""Try and join a room that this server is not in
Args:
remote_room_hosts (list[str]): List of servers that can be used
to join via.
room_id (str): Room that we are trying to join
user (UserID): User who is trying to join
content (dict): A dict that should be used as the content of the
join event.
Returns:
Deferred
"""
if len(remote_room_hosts) == 0: if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
@ -154,6 +166,43 @@ class RoomMemberHandler(object):
) )
yield user_joined_room(self.distributor, user, room_id) yield user_joined_room(self.distributor, user, room_id)
@defer.inlineCallbacks
def _remote_reject_invite(self, remote_room_hosts, room_id, target):
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
Args:
remote_room_hosts (list[str]): List of servers to use to try and
reject invite
room_id (str)
target (UserID): The user rejecting the invite
Returns:
Deferred[dict]: A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
"""
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
target.to_string(),
)
defer.returnValue(ret)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
@defer.inlineCallbacks @defer.inlineCallbacks
def update_membership( def update_membership(
self, self,
@ -212,7 +261,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid # if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join. # invite into a normal invite before we can handle the join.
if third_party_signed is not None: if third_party_signed is not None:
yield self.replication_layer.exchange_third_party_invite( yield self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"], third_party_signed["sender"],
target.to_string(), target.to_string(),
room_id, room_id,
@ -292,7 +341,7 @@ class RoomMemberHandler(object):
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
@ -306,7 +355,7 @@ class RoomMemberHandler(object):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
ret = yield self.remote_join( ret = yield self._remote_join(
remote_room_hosts, room_id, target, content remote_room_hosts, room_id, target, content
) )
defer.returnValue(ret) defer.returnValue(ret)
@ -314,7 +363,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited # perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter: if not inviter:
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
@ -328,28 +377,10 @@ class RoomMemberHandler(object):
else: else:
# send the rejection to the inviter's HS. # send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain] remote_room_hosts = remote_room_hosts + [inviter.domain]
fed_handler = self.federation_handler res = yield self._remote_reject_invite(
try: remote_room_hosts, room_id, target,
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
target.to_string(),
) )
defer.returnValue(ret) defer.returnValue(res)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
res = yield self._local_membership_update( res = yield self._local_membership_update(
requester=requester, requester=requester,
@ -496,7 +527,7 @@ class RoomMemberHandler(object):
defer.returnValue((RoomID.from_string(room_id), servers)) defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_inviter(self, user_id, room_id): def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room( invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id, user_id=user_id,
room_id=room_id, room_id=room_id,
@ -573,7 +604,7 @@ class RoomMemberHandler(object):
if "mxid" in data: if "mxid" in data:
if "signatures" not in data: if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding") raise AuthError(401, "No signatures on 3pid binding")
yield self.verify_any_signature(data, id_server) yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"]) defer.returnValue(data["mxid"])
except IOError as e: except IOError as e:
@ -581,7 +612,7 @@ class RoomMemberHandler(object):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname): def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]: if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
@ -735,20 +766,16 @@ class RoomMemberHandler(object):
} }
if self.config.invite_3pid_guest: if self.config.invite_3pid_guest:
registration_handler = self.registration_handler rh = self.registration_handler
guest_access_token = yield registration_handler.guest_access_token_for( guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest(
medium=medium, medium=medium,
address=address, address=address,
inviter_user_id=inviter_user_id, inviter_user_id=inviter_user_id,
) )
guest_user_info = yield self.auth.get_user_by_access_token(
guest_access_token
)
invite_config.update({ invite_config.update({
"guest_access_token": guest_access_token, "guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(), "guest_user_id": guest_user_id,
}) })
data = yield self.simple_http_client.post_urlencoded_get_json( data = yield self.simple_http_client.post_urlencoded_get_json(

View File

@ -56,7 +56,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -59,6 +60,11 @@ response_count = metrics.register_counter(
) )
) )
requests_counter = metrics.register_counter(
"requests_received",
labels=["method", "servlet", ],
)
outgoing_responses_counter = metrics.register_counter( outgoing_responses_counter = metrics.register_counter(
"responses", "responses",
labels=["method", "code"], labels=["method", "code"],
@ -145,7 +151,8 @@ def wrap_request_handler(request_handler, include_metrics=False):
# at the servlet name. For most requests that name will be # at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render # JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet. # will update it once it picks a servlet.
request_metrics.start(self.clock, name=self.__class__.__name__) servlet_name = self.__class__.__name__
request_metrics.start(self.clock, name=servlet_name)
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
@ -154,6 +161,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
if include_metrics: if include_metrics:
yield request_handler(self, request, request_metrics) yield request_handler(self, request, request_metrics)
else: else:
requests_counter.inc(request.method, servlet_name)
yield request_handler(self, request) yield request_handler(self, request)
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
@ -229,7 +237,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This implements the HttpServer interface and provides JSON support for """ This implements the HttpServer interface and provides JSON support for
Resources. Resources.
Register callbacks via register_path() Register callbacks via register_paths()
Callbacks can return a tuple of status code and a dict in which case the Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object. the dict will automatically be sent to the client as a JSON object.
@ -276,21 +284,7 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
if request.method == "OPTIONS": callback, group_dict = self._get_handler_for_request(request)
self._send_response(request, 200, {})
return
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
if not m:
continue
# We found a match! First update the metrics object to indicate
# which servlet is handling the request.
callback = path_entry.callback
servlet_instance = getattr(callback, "__self__", None) servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None: if servlet_instance is not None:
@ -299,6 +293,7 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = "%r" % callback servlet_classname = "%r" % callback
request_metrics.name = servlet_classname request_metrics.name = servlet_classname
requests_counter.inc(request.method, servlet_classname)
# Now trigger the callback. If it returns a response, we send it # Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper # here. If it throws an exception, that is handled by the wrapper
@ -306,7 +301,7 @@ class JsonResource(HttpServer, resource.Resource):
kwargs = intern_dict({ kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items() for name, value in group_dict.items()
}) })
callback_return = yield callback(request, **kwargs) callback_return = yield callback(request, **kwargs)
@ -314,11 +309,34 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
return def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict
mapping keys to path components as specified in the handler's
path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
"""
if request.method == "OPTIONS":
return _options_handler, {}
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
if m:
# We found a match!
return path_entry.callback, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest" return _unrecognised_request_handler, {}
raise UnrecognizedRequestError()
def _send_response(self, request, code, response_json_object, def _send_response(self, request, code, response_json_object,
response_code_message=None): response_code_message=None):
@ -335,6 +353,34 @@ class JsonResource(HttpServer, resource.Resource):
) )
def _options_handler(request):
"""Request handler for OPTIONS requests
This is a request handler suitable for return from
_get_handler_for_request. It returns a 200 and an empty body.
Args:
request (twisted.web.http.Request):
Returns:
Tuple[int, dict]: http code, response body.
"""
return 200, {}
def _unrecognised_request_handler(request):
"""Request handler for unrecognised requests
This is a request handler suitable for return from
_get_handler_for_request. It actually just raises an
UnrecognizedRequestError.
Args:
request (twisted.web.http.Request):
"""
raise UnrecognizedRequestError()
class RequestMetrics(object): class RequestMetrics(object):
def start(self, clock, name): def start(self, clock, name):
self.start = clock.time_msec() self.start = clock.time_msec()

View File

@ -57,15 +57,31 @@ class Metrics(object):
return metric return metric
def register_counter(self, *args, **kwargs): def register_counter(self, *args, **kwargs):
"""
Returns:
CounterMetric
"""
return self._register(CounterMetric, *args, **kwargs) return self._register(CounterMetric, *args, **kwargs)
def register_callback(self, *args, **kwargs): def register_callback(self, *args, **kwargs):
"""
Returns:
CallbackMetric
"""
return self._register(CallbackMetric, *args, **kwargs) return self._register(CallbackMetric, *args, **kwargs)
def register_distribution(self, *args, **kwargs): def register_distribution(self, *args, **kwargs):
"""
Returns:
DistributionMetric
"""
return self._register(DistributionMetric, *args, **kwargs) return self._register(DistributionMetric, *args, **kwargs)
def register_cache(self, *args, **kwargs): def register_cache(self, *args, **kwargs):
"""
Returns:
CacheMetric
"""
return self._register(CacheMetric, *args, **kwargs) return self._register(CacheMetric, *args, **kwargs)

View File

@ -25,7 +25,7 @@ from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import Requester from synapse.types import Requester, UserID
import logging import logging
import re import re
@ -46,7 +46,7 @@ def send_event_to_master(client, host, port, requester, event, context,
event (FrozenEvent) event (FrozenEvent)
context (EventContext) context (EventContext)
ratelimit (bool) ratelimit (bool)
extra_users (list(str)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
host, port, event.event_id, host, port, event.event_id,
@ -59,7 +59,7 @@ def send_event_to_master(client, host, port, requester, event, context,
"context": context.serialize(event), "context": context.serialize(event),
"requester": requester.serialize(), "requester": requester.serialize(),
"ratelimit": ratelimit, "ratelimit": ratelimit,
"extra_users": extra_users, "extra_users": [u.to_string() for u in extra_users],
} }
try: try:
@ -143,7 +143,7 @@ class ReplicationSendEventRestServlet(RestServlet):
context = yield EventContext.deserialize(self.store, content["context"]) context = yield EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"] ratelimit = content["ratelimit"]
extra_users = content["extra_users"] extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user: if requester.user:
request.authenticated_entity = requester.user.to_string() request.authenticated_entity = requester.user.to_string()

View File

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -185,12 +185,43 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
yield self.handlers.message_handler.purge_history( purge_id = yield self.handlers.message_handler.start_purge_history(
room_id, depth, room_id, depth,
delete_local_events=delete_local_events, delete_local_events=delete_local_events,
) )
defer.returnValue((200, {})) defer.returnValue((200, {
"purge_id": purge_id,
}))
class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/admin/purge_history_status/(?P<purge_id>[^/]+)"
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer)
"""
super(PurgeHistoryStatusRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, purge_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
purge_status = self.handlers.message_handler.get_purge_status(purge_id)
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
defer.returnValue((200, purge_status.asdict()))
class DeactivateAccountRestServlet(ClientV1RestServlet): class DeactivateAccountRestServlet(ClientV1RestServlet):
@ -561,6 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server) WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server)

View File

@ -599,7 +599,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)") "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -32,8 +32,10 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import FederationServer
from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.federation_server import FederationHandlerRegistry
from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
@ -99,7 +101,8 @@ class HomeServer(object):
DEPENDENCIES = [ DEPENDENCIES = [
'http_client', 'http_client',
'db_pool', 'db_pool',
'replication_layer', 'federation_client',
'federation_server',
'handlers', 'handlers',
'v1auth', 'v1auth',
'auth', 'auth',
@ -147,6 +150,7 @@ class HomeServer(object):
'groups_attestation_renewer', 'groups_attestation_renewer',
'spam_checker', 'spam_checker',
'room_member_handler', 'room_member_handler',
'federation_registry',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -195,8 +199,11 @@ class HomeServer(object):
def get_ratelimiter(self): def get_ratelimiter(self):
return self.ratelimiter return self.ratelimiter
def build_replication_layer(self): def build_federation_client(self):
return initialize_http_replication(self) return FederationClient(self)
def build_federation_server(self):
return FederationServer(self)
def build_handlers(self): def build_handlers(self):
return Handlers(self) return Handlers(self)
@ -387,6 +394,9 @@ class HomeServer(object):
def build_room_member_handler(self): def build_room_member_handler(self):
return RoomMemberHandler(self) return RoomMemberHandler(self)
def build_federation_registry(self):
return FederationHandlerRegistry()
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -283,6 +283,7 @@ class EventsStore(EventsWorkerStore):
def _maybe_start_persisting(self, room_id): def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def persisting_queue(item): def persisting_queue(item):
with Measure(self._clock, "persist_events"):
yield self._persist_events( yield self._persist_events(
item.events_and_contexts, item.events_and_contexts,
backfilled=item.backfilled, backfilled=item.backfilled,

View File

@ -245,8 +245,11 @@ class StateGroupWorkerStore(SQLBaseStore):
if types: if types:
clause_to_args = [ clause_to_args = [
( (
"AND type = ? AND state_key = ?" if state_key is not None else "AND type = ?", "AND type = ? AND state_key = ?",
(etype, state_key) if state_key is not None else (etype) (etype, state_key)
) if state_key is not None else (
"AND type = ?",
(etype,)
) )
for etype, state_key in types for etype, state_key in types
] ]
@ -277,22 +280,25 @@ class StateGroupWorkerStore(SQLBaseStore):
results[group][key] = event_id results[group][key] = event_id
else: else:
where_args = [] where_args = []
where_clauses = []
wildcard_types = False
if types is not None: if types is not None:
where_clause = "AND ("
for typ in types: for typ in types:
if typ[1] is None: if typ[1] is None:
where_clause += "(type = ?) OR " where_clauses.append("(type = ?)")
where_args.extend(typ[0]) where_args.extend(typ[0])
wildcard_types = True
else: else:
where_clause += "(type = ? AND state_key = ?) OR " where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]]) where_args.extend([typ[0], typ[1]])
if include_other_types: if include_other_types:
where_clause += "(%s) OR " % ( where_clauses.append(
" AND ".join(["type <> ?"] * len(types)), "(" + " AND ".join(["type <> ?"] * len(types)) + ")"
) )
where_args.extend(t for (t, _) in types) where_args.extend(t for (t, _) in types)
where_clause += "0)" # 0 to terminate the last OR
where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else: else:
where_clause = "" where_clause = ""
@ -322,9 +328,17 @@ class StateGroupWorkerStore(SQLBaseStore):
if (typ, state_key) not in results[group] if (typ, state_key) not in results[group]
) )
# If the lengths match then we must have all the types, # If the number of entries in the (type,state_key)->event_id dict
# so no need to go walk further down the tree. # matches the number of (type,state_keys) types we were searching
if types is not None and len(results[group]) == len(types): # for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
types is not None and
not wildcard_types and
len(results[group]) == len(types)
):
break break
next_group = self._simple_select_one_onecol_txn( next_group = self._simple_select_one_onecol_txn(

View File

@ -292,14 +292,20 @@ class PreserveLoggingContext(object):
def preserve_fn(f): def preserve_fn(f):
"""Wraps a function, to ensure that the current context is restored after """Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs):
return run_in_background(f, *args, **kwargs)
return g
def run_in_background(f, *args, **kwargs):
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the return from the function, and that the sentinel context is set once the
deferred returned by the funtion completes. deferred returned by the funtion completes.
Useful for wrapping functions that return a deferred which you don't yield Useful for wrapping functions that return a deferred which you don't yield
on. on.
""" """
def g(*args, **kwargs):
current = LoggingContext.current_context() current = LoggingContext.current_context()
res = f(*args, **kwargs) res = f(*args, **kwargs)
if isinstance(res, defer.Deferred) and not res.called: if isinstance(res, defer.Deferred) and not res.called:
@ -321,7 +327,6 @@ def preserve_fn(f):
# adding a new exit point.) # adding a new exit point.)
res.addBoth(_set_context_cb, LoggingContext.sentinel) res.addBoth(_set_context_cb, LoggingContext.sentinel)
return res return res
return g
def make_deferred_yieldable(deferred): def make_deferred_yieldable(deferred):

View File

@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock()
"make_query", self.mock_registry = Mock()
"register_edu_handler",
])
self.query_handlers = {} self.query_handlers = {}
def register_query_handler(query_type, handler): def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
replication_layer=self.mock_federation, federation_client=self.mock_federation,
federation_registry=self.mock_registry,
) )
hs.handlers = DirectoryHandlers(hs) hs.handlers = DirectoryHandlers(hs)

View File

@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.hs = yield utils.setup_test_homeserver( self.hs = yield utils.setup_test_homeserver(
handlers=None, handlers=None,
replication_layer=mock.Mock(), federation_client=mock.Mock(),
) )
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)

View File

@ -37,23 +37,23 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock()
"make_query", self.mock_registry = Mock()
"register_edu_handler",
])
self.query_handlers = {} self.query_handlers = {}
def register_query_handler(query_type, handler): def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
handlers=None, handlers=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
replication_layer=self.mock_federation, federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]) ])

View File

@ -81,7 +81,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_current_state_deltas", "get_current_state_deltas",
]), ]),
state_handler=self.state_handler, state_handler=self.state_handler,
handlers=None, handlers=Mock(),
notifier=mock_notifier, notifier=mock_notifier,
resource_for_client=Mock(), resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource, resource_for_federation=self.mock_federation_resource,

View File

@ -31,7 +31,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver( self.hs = yield setup_test_homeserver(
"blue", "blue",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),

View File

@ -114,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),

View File

@ -45,7 +45,7 @@ class ProfileTestCase(unittest.TestCase):
http_client=None, http_client=None,
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
federation=Mock(), federation=Mock(),
replication_layer=Mock(), federation_client=Mock(),
profile_handler=self.mock_handler profile_handler=self.mock_handler
) )

View File

@ -46,7 +46,7 @@ class RoomPermissionsTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -409,7 +409,7 @@ class RoomsMemberListTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -493,7 +493,7 @@ class RoomsCreateTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -582,7 +582,7 @@ class RoomTopicTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -697,7 +697,7 @@ class RoomMemberStateTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -829,7 +829,7 @@ class RoomMessagesTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -929,7 +929,7 @@ class RoomInitialSyncTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
@ -1003,7 +1003,7 @@ class RoomMessageListTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()

View File

@ -47,7 +47,7 @@ class RoomTypingTestCase(RestTestCase):
"red", "red",
clock=self.clock, clock=self.clock,
http_client=None, http_client=None,
replication_layer=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),

View File

@ -42,7 +42,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
federation_sender=Mock(), federation_sender=Mock(),
replication_layer=Mock(), federation_client=Mock(),
) )
self.as_token = "token1" self.as_token = "token1"
@ -119,7 +119,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
federation_sender=Mock(), federation_sender=Mock(),
replication_layer=Mock(), federation_client=Mock(),
) )
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
@ -455,7 +455,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock(), federation_sender=Mock(),
replication_layer=Mock(), federation_client=Mock(),
) )
ApplicationServiceStore(None, hs) ApplicationServiceStore(None, hs)
@ -473,7 +473,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock(), federation_sender=Mock(),
replication_layer=Mock(), federation_client=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
@ -497,7 +497,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock(), federation_sender=Mock(),
replication_layer=Mock(), federation_client=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm: