Fix logcontexts in _check_sigs_and_hashes

This commit is contained in:
Richard van der Hoff 2017-09-20 01:32:42 +01:00
parent 72472456d8
commit 6de74ea6d7
2 changed files with 59 additions and 57 deletions

View File

@ -18,8 +18,7 @@ from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import spamcheck from synapse.events import spamcheck
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import preserve_context_over_deferred
from twisted.internet import defer from twisted.internet import defer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,56 +50,52 @@ class FederationBase(object):
""" """
deferreds = self._check_sigs_and_hashes(pdus) deferreds = self._check_sigs_and_hashes(pdus)
def callback(pdu): @defer.inlineCallbacks
return pdu def handle_check_result(pdu, deferred):
try:
res = yield logcontext.make_deferred_yieldable(deferred)
except SynapseError:
res = None
def errback(failure, pdu):
failure.trap(SynapseError)
return None
def try_local_db(res, pdu):
if not res: if not res:
# Check local db. # Check local db.
return self.store.get_event( res = yield self.store.get_event(
pdu.event_id, pdu.event_id,
allow_rejected=True, allow_rejected=True,
allow_none=True, allow_none=True,
) )
return res
def try_remote(res, pdu):
if not res and pdu.origin != origin: if not res and pdu.origin != origin:
return self.get_pdu( try:
destinations=[pdu.origin], res = yield self.get_pdu(
event_id=pdu.event_id, destinations=[pdu.origin],
outlier=outlier, event_id=pdu.event_id,
timeout=10000, outlier=outlier,
).addErrback(lambda e: None) timeout=10000,
return res )
except SynapseError:
pass
def warn(res, pdu):
if not res: if not res:
logger.warn( logger.warn(
"Failed to find copy of %s with valid signature", "Failed to find copy of %s with valid signature",
pdu.event_id, pdu.event_id,
) )
return res
for pdu, deferred in zip(pdus, deferreds): defer.returnValue(res)
deferred.addCallbacks(
callback, errback, errbackArgs=[pdu] handle = logcontext.preserve_fn(handle_check_result)
).addCallback( deferreds2 = [
try_local_db, pdu handle(pdu, deferred)
).addCallback( for pdu, deferred in zip(pdus, deferreds)
try_remote, pdu ]
).addCallback(
warn, pdu valid_pdus = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
deferreds2,
consumeErrors=True,
) )
).addErrback(unwrapFirstError)
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
deferreds,
consumeErrors=True
)).addErrback(unwrapFirstError)
if include_none: if include_none:
defer.returnValue(valid_pdus) defer.returnValue(valid_pdus)
@ -108,7 +103,9 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p]) defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu): def _check_sigs_and_hash(self, pdu):
return self._check_sigs_and_hashes([pdu])[0] return logcontext.make_deferred_yieldable(
self._check_sigs_and_hashes([pdu])[0],
)
def _check_sigs_and_hashes(self, pdus): def _check_sigs_and_hashes(self, pdus):
"""Checks that each of the received events is correctly signed by the """Checks that each of the received events is correctly signed by the
@ -123,6 +120,7 @@ class FederationBase(object):
* returns a redacted version of the event (if the signature * returns a redacted version of the event (if the signature
matched but the hash did not) matched but the hash did not)
* throws a SynapseError if the signature check failed. * throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
""" """
redacted_pdus = [ redacted_pdus = [
@ -135,29 +133,33 @@ class FederationBase(object):
for p in redacted_pdus for p in redacted_pdus
]) ])
ctx = logcontext.LoggingContext.current_context()
def callback(_, pdu, redacted): def callback(_, pdu, redacted):
if not check_event_content_hash(pdu): with logcontext.PreserveLoggingContext(ctx):
logger.warn( if not check_event_content_hash(pdu):
"Event content has been tampered, redacting %s: %s", logger.warn(
pdu.event_id, pdu.get_pdu_json() "Event content has been tampered, redacting %s: %s",
) pdu.event_id, pdu.get_pdu_json()
return redacted )
return redacted
if spamcheck.check_event_for_spam(pdu): if spamcheck.check_event_for_spam(pdu):
logger.warn( logger.warn(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json() pdu.event_id, pdu.get_pdu_json()
) )
return redacted return redacted
return pdu return pdu
def errback(failure, pdu): def errback(failure, pdu):
failure.trap(SynapseError) failure.trap(SynapseError)
logger.warn( with logcontext.PreserveLoggingContext(ctx):
"Signature check failed for %s", logger.warn(
pdu.event_id, "Signature check failed for %s",
) pdu.event_id,
)
return failure return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):

View File

@ -22,7 +22,7 @@ from synapse.api.constants import Membership
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -189,10 +189,10 @@ class FederationClient(FederationBase):
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus), self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
defer.returnValue(pdus) defer.returnValue(pdus)
@ -252,7 +252,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] signed_pdu = yield self._check_sigs_and_hash(pdu)
break break