Merge branch 'develop' into dbkr/channel_notifications

This commit is contained in:
David Baker 2017-10-10 11:20:17 +01:00
commit 89fa00ddff
7 changed files with 183 additions and 107 deletions

View File

@ -46,7 +46,7 @@ class SpamChecker(object):
return self.spam_checker.check_event_for_spam(event) return self.spam_checker.check_event_for_spam(event)
def user_may_invite(self, userid, room_id): def user_may_invite(self, inviter_userid, invitee_userid, room_id):
"""Checks if a given user may send an invite """Checks if a given user may send an invite
If this method returns false, the invite will be rejected. If this method returns false, the invite will be rejected.
@ -60,7 +60,7 @@ class SpamChecker(object):
if self.spam_checker is None: if self.spam_checker is None:
return True return True
return self.spam_checker.user_may_invite(userid, room_id) return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
def user_may_create_room(self, userid): def user_may_create_room(self, userid):
"""Checks if a given user may create a room """Checks if a given user may create a room

View File

@ -12,14 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.async import Linearizer from synapse.util import async
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
import simplejson as json import simplejson as json
import logging import logging
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,7 +53,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._server_linearizer = Linearizer("fed_server") self._server_linearizer = async.Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
@ -109,25 +111,41 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_incoming_transaction(self, transaction_data): def on_incoming_transaction(self, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
received_pdus_counter.inc_by(len(transaction.pdus)) if not transaction.transaction_id:
raise Exception("Transaction missing transaction_id")
for p in transaction.pdus: if not transaction.origin:
if "unsigned" in p: raise Exception("Transaction missing origin")
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id) logger.debug("[%s] Got transaction", transaction.transaction_id)
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (yield self._transaction_linearizer.queue(
(transaction.origin, transaction.transaction_id),
)):
result = yield self._handle_incoming_transaction(
transaction, request_time,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _handle_incoming_transaction(self, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
transaction (Transaction): incoming transaction
request_time (int): timestamp that the HTTP request arrived at
Returns:
Deferred[(int, object)]: http response code and body
"""
response = yield self.transaction_actions.have_responded(transaction) response = yield self.transaction_actions.have_responded(transaction)
if response: if response:
@ -140,42 +158,50 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
results = [] received_pdus_counter.inc_by(len(transaction.pdus))
for pdu in pdu_list: pdus_by_room = {}
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if transaction.origin != get_domain_from_id(pdu.event_id):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join'
):
logger.info(
"Discarding PDU %s from invalid origin %s",
pdu.event_id, transaction.origin
)
continue
else:
logger.info(
"Accepting join PDU %s from %s",
pdu.event_id, transaction.origin
)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = request_time - int(p["age"])
del p["age"]
event = self.event_from_pdu_json(p)
room_id = event.room_id
pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {}
# we can process different rooms in parallel (which is useful if they
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
@defer.inlineCallbacks
def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
try: try:
yield self._handle_received_pdu(transaction.origin, pdu) yield self._handle_received_pdu(
results.append({}) transaction.origin, pdu
)
pdu_results[event_id] = {}
except FederationError as e: except FederationError as e:
logger.warn("Error handling PDU %s: %s", event_id, e)
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)
results.append({"error": str(e)}) pdu_results[event_id] = {"error": str(e)}
except Exception as e: except Exception as e:
results.append({"error": str(e)}) pdu_results[event_id] = {"error": str(e)}
logger.exception("Failed to handle PDU") logger.exception("Failed to handle PDU %s", event_id)
yield async.concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(),
TRANSACTION_CONCURRENCY_LIMIT,
)
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus): for edu in (Edu(**x) for x in transaction.edus):
@ -188,14 +214,12 @@ class FederationServer(FederationBase):
for failure in getattr(transaction, "pdu_failures", []): for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure) logger.info("Got failure %r", failure)
logger.debug("Returning: %s", str(results))
response = { response = {
"pdus": dict(zip( "pdus": pdu_results,
(p.event_id for p in pdu_list), results
)),
} }
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response( yield self.transaction_actions.set_response(
transaction, transaction,
200, response 200, response
@ -520,6 +544,30 @@ class FederationServer(FederationBase):
Returns (Deferred): completes with None Returns (Deferred): completes with None
Raises: FederationError if the signatures / hash do not match Raises: FederationError if the signatures / hash do not match
""" """
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if origin != get_domain_from_id(pdu.event_id):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join'
):
logger.info(
"Discarding PDU %s from invalid origin %s",
pdu.event_id, origin
)
return
else:
logger.info(
"Accepting join PDU %s from %s",
pdu.event_id, origin
)
# Check signature. # Check signature.
try: try:
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hash(pdu)

View File

@ -20,8 +20,8 @@ from .persistence import TransactionActions
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util import logcontext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
@ -231,11 +231,9 @@ class TransactionQueue(object):
(pdu, order) (pdu, order)
) )
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
@preserve_fn # the caller should not yield on this @logcontext.preserve_fn # the caller should not yield on this
@defer.inlineCallbacks @defer.inlineCallbacks
def send_presence(self, states): def send_presence(self, states):
"""Send the new presence states to the appropriate destinations. """Send the new presence states to the appropriate destinations.
@ -299,7 +297,7 @@ class TransactionQueue(object):
state.user_id: state for state in states state.user_id: state for state in states
}) })
preserve_fn(self._attempt_new_transaction)(destination) self._attempt_new_transaction(destination)
def send_edu(self, destination, edu_type, content, key=None): def send_edu(self, destination, edu_type, content, key=None):
edu = Edu( edu = Edu(
@ -321,9 +319,7 @@ class TransactionQueue(object):
else: else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu) self.pending_edus_by_dest.setdefault(destination, []).append(edu)
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
def send_failure(self, failure, destination): def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
@ -336,9 +332,7 @@ class TransactionQueue(object):
destination, [] destination, []
).append(failure) ).append(failure)
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
def send_device_messages(self, destination): def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
@ -347,15 +341,24 @@ class TransactionQueue(object):
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
preserve_context_over_fn( self._attempt_new_transaction(destination)
self._attempt_new_transaction, destination
)
def get_current_token(self): def get_current_token(self):
return 0 return 0
@defer.inlineCallbacks
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
"""Try to start a new transaction to this destination
If there is already a transaction in progress to this destination,
returns immediately. Otherwise kicks off the process of sending a
transaction in the background.
Args:
destination (str):
Returns:
None
"""
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: pending_transactions can get stuck on by a never-ending
@ -368,6 +371,19 @@ class TransactionQueue(object):
) )
return return
logger.debug("TX [%s] Starting transaction loop", destination)
# Drop the logcontext before starting the transaction. It doesn't
# really make sense to log all the outbound transactions against
# whatever path led us to this point: that's pretty arbitrary really.
#
# (this also means we can fire off _perform_transaction without
# yielding)
with logcontext.PreserveLoggingContext():
self._transaction_transmission_loop(destination)
@defer.inlineCallbacks
def _transaction_transmission_loop(self, destination):
pending_pdus = [] pending_pdus = []
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""Contains handlers for federation events.""" """Contains handlers for federation events."""
import synapse.util.logcontext
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -26,10 +25,7 @@ from synapse.api.errors import (
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
preserve_fn, preserve_context_over_deferred
)
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, Linearizer from synapse.util.async import run_on_reactor, Linearizer
@ -592,9 +588,9 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch missing_auth - failed_to_fetch
) )
results = yield preserve_context_over_deferred(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self.replication_layer.get_pdu)( logcontext.preserve_fn(self.replication_layer.get_pdu)(
[dest], [dest],
event_id, event_id,
outlier=True, outlier=True,
@ -786,10 +782,14 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
states = yield preserve_context_over_deferred(defer.gatherResults([ states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) [
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
room_id, [e]
)
for e in event_ids for e in event_ids
])) ], consumeErrors=True,
))
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
@ -942,9 +942,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually # lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it. # have. Hence we fire off the deferred, but don't wait for it.
synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
room_queue
)
defer.returnValue(True) defer.returnValue(True)
@ -1071,6 +1069,9 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
if event.state_key is None:
raise SynapseError(400, "The invite event did not have a state key")
is_blocked = yield self.store.is_room_blocked(event.room_id) is_blocked = yield self.store.is_room_blocked(event.room_id)
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
@ -1078,9 +1079,11 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites: if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites") raise SynapseError(403, "This server does not accept room invites")
if not self.spam_checker.user_may_invite(event.sender, event.room_id): if not self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id,
):
raise SynapseError( raise SynapseError(
403, "This user is not permitted to send invites to this server" 403, "This user is not permitted to send invites to this server/user"
) )
membership = event.content.get("membership") membership = event.content.get("membership")
@ -1091,9 +1094,6 @@ class FederationHandler(BaseHandler):
if sender_domain != origin: if sender_domain != origin:
raise SynapseError(400, "The invite event was not from the server sending it") raise SynapseError(400, "The invite event was not from the server sending it")
if event.state_key is None:
raise SynapseError(400, "The invite event did not have a state key")
if not self.is_mine_id(event.state_key): if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server") raise SynapseError(400, "The invite event must be for this server")
@ -1436,7 +1436,7 @@ class FederationHandler(BaseHandler):
if not backfilled: if not backfilled:
# this intentionally does not yield: we don't care about the result # this intentionally does not yield: we don't care about the result
# and don't need to wait for it. # and don't need to wait for it.
preserve_fn(self.pusher_pool.on_new_notifications)( logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
event_stream_id, max_stream_id event_stream_id, max_stream_id
) )
@ -1449,16 +1449,16 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations. on each other for state calculations.
""" """
contexts = yield preserve_context_over_deferred(defer.gatherResults( contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self._prep_event)( logcontext.preserve_fn(self._prep_event)(
origin, origin,
ev_info["event"], ev_info["event"],
state=ev_info.get("state"), state=ev_info.get("state"),
auth_events=ev_info.get("auth_events"), auth_events=ev_info.get("auth_events"),
) )
for ev_info in event_infos for ev_info in event_infos
] ], consumeErrors=True,
)) ))
yield self.store.persist_events( yield self.store.persist_events(
@ -1766,18 +1766,17 @@ class FederationHandler(BaseHandler):
# Do auth conflict res. # Do auth conflict res.
logger.info("Different auth: %s", different_auth) logger.info("Different auth: %s", different_auth)
different_events = yield preserve_context_over_deferred(defer.gatherResults( different_events = yield logcontext.make_deferred_yieldable(
[ defer.gatherResults([
preserve_fn(self.store.get_event)( logcontext.preserve_fn(self.store.get_event)(
d, d,
allow_none=True, allow_none=True,
allow_rejected=False, allow_rejected=False,
) )
for d in different_auth for d in different_auth
if d in have_events and not have_events[d] if d in have_events and not have_events[d]
], ], consumeErrors=True)
consumeErrors=True ).addErrback(unwrapFirstError)
)).addErrback(unwrapFirstError)
if different_events: if different_events:
local_view = dict(auth_events) local_view = dict(auth_events)

View File

@ -225,7 +225,7 @@ class RoomMemberHandler(BaseHandler):
block_invite = True block_invite = True
if not self.spam_checker.user_may_invite( if not self.spam_checker.user_may_invite(
requester.user.to_string(), room_id, requester.user.to_string(), target.to_string(), room_id,
): ):
logger.info("Blocking invite due to spam checker") logger.info("Blocking invite due to spam checker")
block_invite = True block_invite = True

View File

@ -288,6 +288,9 @@ class StateHandler(object):
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = yield self.store.get_state_groups_ids( state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids room_id, event_ids
) )
@ -320,11 +323,15 @@ class StateHandler(object):
"Resolving state for %s with %d groups", room_id, len(state_groups_ids) "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
) )
# build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str])
state = {} state = {}
for st in state_groups_ids.values(): for st in state_groups_ids.values():
for key, e_id in st.items(): for key, e_id in st.items():
state.setdefault(key, set()).add(e_id) state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict.
conflicted_state = { conflicted_state = {
k: list(v) k: list(v)
for k, v in state.items() for k, v in state.items()
@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
logger.info("Asking for %d conflicted events", len(needed_events)) logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict.
state_map = yield state_map_factory(needed_events) state_map = yield state_map_factory(needed_events)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps( auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map unconflicted_state, conflicted_state, state_map
) )

View File

@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from .logcontext import ( from .logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
) )
from synapse.util import unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from contextlib import contextmanager from contextlib import contextmanager
@ -155,7 +155,7 @@ def concurrently_execute(func, args, limit):
except StopIteration: except StopIteration:
pass pass
return preserve_context_over_deferred(defer.gatherResults([ return logcontext.make_deferred_yieldable(defer.gatherResults([
preserve_fn(_concurrently_execute_inner)() preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit) for _ in xrange(limit)
], consumeErrors=True)).addErrback(unwrapFirstError) ], consumeErrors=True)).addErrback(unwrapFirstError)