Fix up logcontexts

This commit is contained in:
Erik Johnston 2016-02-04 10:22:44 +00:00
parent 13e6262659
commit 2c1fbea531
31 changed files with 356 additions and 229 deletions

View File

@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -529,7 +530,8 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
self.store.insert_client_ip( preserve_context_over_fn(
self.store.insert_client_ip,
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,

View File

@ -709,6 +709,8 @@ def run(hs):
phone_home_task.start(60 * 60 * 24, now=False) phone_home_task.start(60 * 60 * 24, now=False)
def in_thread(): def in_thread():
# Uncomment to enable tracing of log context changes.
# sys.settrace(logcontext_tracer)
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
reactor.run() reactor.run()

View File

@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from twisted.internet import defer from twisted.internet import defer
@ -142,6 +146,8 @@ class Keyring(object):
for server_name, _ in server_and_json for server_name, _ in server_and_json
} }
with PreserveLoggingContext():
# We want to wait for any previous lookups to complete before # We want to wait for any previous lookups to complete before
# proceeding. # proceeding.
wait_on_deferred = self.wait_for_previous_lookups( wait_on_deferred = self.wait_for_previous_lookups(
@ -175,7 +181,8 @@ class Keyring(object):
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
handle_key_deferred( preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id], group_id_to_group[g_id],
deferreds[g_id], deferreds[g_id],
) )
@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads if server_name in self.key_downloads
] ]
if wait_on: if wait_on:
with PreserveLoggingContext():
yield defer.DeferredList(wait_on) yield defer.DeferredList(wait_on)
else: else:
break break
for server_name, deferred in server_to_deferred.items(): for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred) d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d self.key_downloads[server_name] = d
def rm(r, server_name): def rm(r, server_name):
@ -244,6 +252,7 @@ class Keyring(object):
for group in group_id_to_group.values(): for group in group_id_to_group.values():
for key_id in group.key_ids: for key_id in group.key_ids:
if key_id in merged_results[group.server_name]: if key_id in merged_results[group.server_name]:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback(( group_id_to_deferred[group.group_id].callback((
group.group_id, group.group_id,
group.server_name, group.server_name,
@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store_keys( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
from_server=server_name, from_server=server_name,
verify_keys=verify_keys, verify_keys=verify_keys,
@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_keys_json( preserve_fn(self.store.store_server_keys_json)(
server_name=server_name, server_name=server_name,
key_id=key_id, key_id=key_id,
from_server=server_name, from_server=server_name,
@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.store.store_server_verify_key( preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
) )
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()

View File

@ -126,10 +126,8 @@ class FederationServer(FederationBase):
results = [] results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try: try:
yield d yield self._handle_new_pdu(transaction.origin, pdu)
results.append({}) results.append({})
except FederationError as e: except FederationError as e:
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)

View File

@ -103,7 +103,6 @@ class TransactionQueue(object):
else: else:
return not destination.startswith("localhost") return not destination.startswith("localhost")
@defer.inlineCallbacks
def enqueue_pdu(self, pdu, destinations, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
@ -141,8 +140,6 @@ class TransactionQueue(object):
deferreds.append(deferred) deferreds.append(deferred)
yield defer.DeferredList(deferreds, consumeErrors=True)
# NO inlineCallbacks # NO inlineCallbacks
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination

View File

@ -293,19 +293,11 @@ class BaseHandler(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners. # Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
notify_d.addErrback(log_failure)
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn
from ._base import BaseHandler from ._base import BaseHandler
@ -29,11 +30,17 @@ logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user): def started_user_eventstream(distributor, user):
return distributor.fire("started_user_eventstream", user) return preserve_context_over_fn(
distributor.fire,
"started_user_eventstream", user
)
def stopped_user_eventstream(distributor, user): def stopped_user_eventstream(distributor, user):
return distributor.fire("stopped_user_eventstream", user) return preserve_context_over_fn(
distributor.fire,
"stopped_user_eventstream", user
)
class EventStreamHandler(BaseHandler): class EventStreamHandler(BaseHandler):

View File

@ -221,19 +221,11 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -643,19 +635,11 @@ class FederationHandler(BaseHandler):
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[joinee] extra_users=[joinee]
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
@ -730,18 +714,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
@ -811,19 +787,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=[target_user], extra_users=[target_user],
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -948,18 +916,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user) extra_users.append(target_user)
with PreserveLoggingContext(): with PreserveLoggingContext():
d = self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users event, event_stream_id, max_stream_id, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event new_pdu = event
destinations = set() destinations = set()

View File

@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
was_polling = target_user in self._user_cachemap was_polling = target_user in self._user_cachemap
if now_online and not was_polling: if now_online and not was_polling:
self.start_polling_presence(target_user, state=state) yield self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling: elif not now_online and was_polling:
self.stop_polling_presence(target_user) yield self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
@ -394,6 +394,7 @@ class PresenceHandler(BaseHandler):
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return return
with PreserveLoggingContext():
self.changed_presencelike_data(user, {"last_active": now}) self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
@ -466,6 +467,7 @@ class PresenceHandler(BaseHandler):
local_user, room_ids=[room_id], add_to_cache=False local_user, room_ids=[room_id], add_to_cache=False
) )
with PreserveLoggingContext():
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=local_user, observed_user=local_user,
users_to_push=[user], users_to_push=[user],
@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
self.start_polling_presence( yield self.start_polling_presence(
observer_user, target_user=observed_user observer_user, target_user=observed_user
) )

View File

@ -186,7 +186,7 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash="" password_hash=""
) )
registered_user(self.distributor, user) yield registered_user(self.distributor, user)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -25,6 +25,7 @@ from synapse.api.constants import (
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
return distributor.fire("user_left_room", user=user, room_id=room_id) return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id): def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user=user, room_id=room_id) return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):

View File

@ -18,7 +18,7 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from twisted.internet import defer from twisted.internet import defer
@ -241,6 +241,7 @@ class SyncHandler(BaseHandler):
deferreds = [] deferreds = []
for event in room_list: for event in room_list:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
with PreserveLoggingContext(LoggingContext.current_context()):
room_sync_deferred = self.full_state_sync_for_joined_room( room_sync_deferred = self.full_state_sync_for_joined_room(
room_id=event.room_id, room_id=event.room_id,
sync_config=sync_config, sync_config=sync_config,
@ -262,6 +263,7 @@ class SyncHandler(BaseHandler):
leave_token = now_token.copy_and_replace( leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,) "room_key", "s%d" % (event.stream_ordering,)
) )
with PreserveLoggingContext(LoggingContext.current_context()):
room_sync_deferred = self.full_state_sync_for_archived_room( room_sync_deferred = self.full_state_sync_for_archived_room(
sync_config=sync_config, sync_config=sync_config,
room_id=event.room_id, room_id=event.room_id,

View File

@ -99,9 +99,8 @@ def request_handler(request_handler):
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
try: try:
d = request_handler(self, request) with PreserveLoggingContext(request_context):
with PreserveLoggingContext(): yield request_handler(self, request)
yield d
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
if isinstance(e, SynapseError): if isinstance(e, SynapseError):

View File

@ -18,7 +18,8 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -73,6 +74,7 @@ class _NotifierUserStream(object):
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key, stream_id, time_now_ms):
@ -88,6 +90,8 @@ class _NotifierUserStream(object):
) )
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred noify_deferred = self.notify_deferred
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token) noify_deferred.callback(self.current_token)
@ -184,8 +188,6 @@ class Notifier(object):
lambda: count(bool, self.appservice_to_user_streams.values()), lambda: count(bool, self.appservice_to_user_streams.values()),
) )
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
@ -199,8 +201,7 @@ class Notifier(object):
until all previous events have been persisted before notifying until all previous events have been persisted before notifying
the client streams. the client streams.
""" """
yield run_on_reactor() with PreserveLoggingContext():
self.pending_new_room_events.append(( self.pending_new_room_events.append((
room_stream_id, event, extra_users room_stream_id, event, extra_users
)) ))
@ -251,15 +252,13 @@ class Notifier(object):
extra_streams=app_streams, extra_streams=app_streams,
) )
@defer.inlineCallbacks
@log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[], def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()): extra_streams=set()):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor() with PreserveLoggingContext():
user_streams = set() user_streams = set()
for user in users: for user in users:
@ -325,6 +324,7 @@ class Notifier(object):
# that we don't miss any current_token updates. # that we don't miss any current_token updates.
prev_token = current_token prev_token = current_token
listener = user_stream.new_listener(prev_token) listener = user_stream.new_listener(prev_token)
with PreserveLoggingContext():
yield listener.deferred yield listener.deferred
except defer.CancelledError: except defer.CancelledError:
break break

View File

@ -111,7 +111,7 @@ class Pusher(object):
self.user_id, config, timeout=0, affect_presence=False self.user_id, config, timeout=0, affect_presence=False
) )
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( yield self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_id, self.last_token self.app_id, self.pushkey, self.user_id, self.last_token
) )
logger.info("New pusher %s for user %s starting from token %s", logger.info("New pusher %s for user %s starting from token %s",

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from httppusher import HttpPusher from httppusher import HttpPusher
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn
import logging import logging
@ -76,7 +77,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name'] app_id, pushkey, p['user_name']
) )
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id): def remove_pushers_by_user(self, user_id):
@ -91,7 +92,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']
) )
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
@ -110,7 +111,7 @@ class PusherPool:
lang=lang, lang=lang,
data=data, data=data,
) )
self._refresh_pusher(app_id, pushkey, user_id) yield self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
@ -166,7 +167,7 @@ class PusherPool:
if fullid in self.pushers: if fullid in self.pushers:
self.pushers[fullid].stop() self.pushers[fullid].stop()
self.pushers[fullid] = p self.pushers[fullid] = p
p.start() preserve_fn(p.start)()
logger.info("Started pushers") logger.info("Started pushers")

View File

@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet):
user_id, account_data_type, body user_id, account_data_type, body
) )
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet):
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body
) )
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )

View File

@ -80,7 +80,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
@ -94,7 +94,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event( self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )

View File

@ -15,7 +15,7 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
import synapse.metrics import synapse.metrics
@ -298,8 +298,8 @@ class SQLBaseStore(object):
func, *args, **kwargs func, *args, **kwargs
) )
result = yield preserve_context_over_fn( with PreserveLoggingContext():
self._db_pool.runWithConnection, result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
@ -326,8 +326,8 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs) return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn( with PreserveLoggingContext():
self._db_pool.runWithConnection, result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )

View File

@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from synapse.events import FrozenEvent, USE_FROZEN_DICTS from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -664,6 +664,7 @@ class EventsStore(SQLBaseStore):
for ids, d in lst: for ids, d in lst:
if not d.called: if not d.called:
try: try:
with PreserveLoggingContext():
d.callback([ d.callback([
res[i] res[i]
for i in ids for i in ids
@ -671,6 +672,7 @@ class EventsStore(SQLBaseStore):
]) ])
except: except:
logger.exception("Failed to callback") logger.exception("Failed to callback")
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict) reactor.callFromThread(fire, event_list, row_dict)
except Exception as e: except Exception as e:
logger.exception("do_fetch") logger.exception("do_fetch")
@ -679,9 +681,11 @@ class EventsStore(SQLBaseStore):
def fire(evs): def fire(evs):
for _, d in evs: for _, d in evs:
if not d.called: if not d.called:
with PreserveLoggingContext():
d.errback(e) d.errback(e)
if event_list: if event_list:
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list) reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -709,18 +713,20 @@ class EventsStore(SQLBaseStore):
should_start = False should_start = False
if should_start: if should_start:
with PreserveLoggingContext():
self.runWithConnection( self.runWithConnection(
self._do_fetch self._do_fetch
) )
rows = yield preserve_context_over_deferred(events_d) with PreserveLoggingContext():
rows = yield events_d
if not allow_rejected: if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]] rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults( res = yield defer.gatherResults(
[ [
self._get_event_from_row( preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,

View File

@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore):
for row in rows for row in rows
}) })
@defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state): def set_presence_state(self, user_localpart, new_state):
res = self._simple_update_one( res = yield self._simple_update_one(
table="presence", table="presence",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"], updatevalues={"state": new_state["state"],
@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore):
) )
self.get_presence_state.invalidate((user_localpart,)) self.get_presence_state.invalidate((user_localpart,))
return res defer.returnValue(res)
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert( return self._simple_insert(

View File

@ -39,6 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn
import logging import logging
@ -170,12 +171,12 @@ class StreamStore(SQLBaseStore):
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([ res = yield defer.gatherResults([
self.get_room_events_stream_for_room( preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit room_id, from_key, to_key, limit,
).addCallback(lambda r, rm: (rm, r), room_id) )
for room_id in room_ids for room_id in room_ids
]) ])
results.update(dict(res)) results.update(dict(zip(rm_ids, res)))
defer.returnValue(results) defer.returnValue(results)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task from twisted.internet import defer, reactor, task
@ -61,10 +61,8 @@ class Clock(object):
*args: Postional arguments to pass to function. *args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function. **kwargs: Key arguments to pass to function.
""" """
current_context = LoggingContext.current_context()
def wrapped_callback(*args, **kwargs): def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext(current_context): with PreserveLoggingContext():
callback(*args, **kwargs) callback(*args, **kwargs)
with PreserveLoggingContext(): with PreserveLoggingContext():

View File

@ -16,13 +16,16 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from .logcontext import preserve_context_over_deferred from .logcontext import PreserveLoggingContext
@defer.inlineCallbacks
def sleep(seconds): def sleep(seconds):
d = defer.Deferred() d = defer.Deferred()
with PreserveLoggingContext():
reactor.callLater(seconds, d.callback, seconds) reactor.callLater(seconds, d.callback, seconds)
return preserve_context_over_deferred(d) res = yield d
defer.returnValue(res)
def run_on_reactor(): def run_on_reactor():
@ -54,6 +57,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (True, r)) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
# TODO: Handle errors here.
self._observers.pop().callback(r) self._observers.pop().callback(r)
except: except:
pass pass
@ -63,6 +67,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (False, f)) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
# TODO: Handle errors here.
self._observers.pop().errback(f) self._observers.pop().errback(f)
except: except:
pass pass

View File

@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
from . import caches_by_name, DEBUG_CACHES, cache_counter from . import caches_by_name, DEBUG_CACHES, cache_counter
@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result) defer.returnValue(cached_result)
observer.addCallback(check_result) observer.addCallback(check_result)
return observer return preserve_context_over_deferred(observer)
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence sequence = self.cache.sequence
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call, self.function_to_call,
obj, *args, **kwargs obj, *args, **kwargs
) )
@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret) self.cache.update(sequence, cache_key, ret)
return ret.observe() return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all wrapped.invalidate_all = self.cache.invalidate_all
@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred( ret_d = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call, self.function_to_call,
**args_to_call **args_to_call
) )
@ -308,6 +313,7 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that # We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache. # we can insert the new deferred into the cache.
for arg in missing: for arg in missing:
with PreserveLoggingContext():
observer = ret_d.observe() observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer.addCallback(lambda r, arg: r.get(arg, None), arg)
@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res cached[arg] = res
return defer.gatherResults( return preserve_context_over_deferred(defer.gatherResults(
cached.values(), cached.values(),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped

View File

@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache. # expire from the rotation of that cache.
self.next_result_cache[key] = result self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None) self.pending_result_cache.pop(key, None)
return r
result.observe().addBoth(shuffle_along) result.addBoth(shuffle_along)
return result.observe() return result.observe()

View File

@ -15,9 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import ( from synapse.util.logcontext import PreserveLoggingContext
PreserveLoggingContext, preserve_context_over_deferred,
)
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
@ -97,6 +95,7 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -116,6 +115,7 @@ class Signal(object):
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures: if not self.suppress_failures:
return failure return failure
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -124,8 +124,11 @@ class Signal(object):
for observer in self.observers for observer in self.observers
] ]
d = defer.gatherResults(deferreds, consumeErrors=True) res = yield defer.gatherResults(
deferreds, consumeErrors=True
).addErrback(unwrapFirstError)
d.addErrback(unwrapFirstError) defer.returnValue(res)
return preserve_context_over_deferred(d) def __repr__(self):
return "<Signal name=%r>" % (self.name,)

View File

@ -48,7 +48,7 @@ class LoggingContext(object):
__slots__ = [ __slots__ = [
"parent_context", "name", "usage_start", "usage_end", "main_thread", "parent_context", "name", "usage_start", "usage_end", "main_thread",
"__dict__", "tag", "__dict__", "tag", "alive",
] ]
thread_local = threading.local() thread_local = threading.local()
@ -88,6 +88,7 @@ class LoggingContext(object):
self.usage_start = None self.usage_start = None
self.main_thread = threading.current_thread() self.main_thread = threading.current_thread()
self.tag = "" self.tag = ""
self.alive = True
def __str__(self): def __str__(self):
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@ -106,6 +107,7 @@ class LoggingContext(object):
The context that was previously active The context that was previously active
""" """
current = cls.current_context() current = cls.current_context()
if current is not context: if current is not context:
current.stop() current.stop()
cls.thread_local.current_context = context cls.thread_local.current_context = context
@ -117,6 +119,7 @@ class LoggingContext(object):
if self.parent_context is not None: if self.parent_context is not None:
raise Exception("Attempt to enter logging context multiple times") raise Exception("Attempt to enter logging context multiple times")
self.parent_context = self.set_current_context(self) self.parent_context = self.set_current_context(self)
self.alive = True
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
@ -136,6 +139,7 @@ class LoggingContext(object):
self self
) )
self.parent_context = None self.parent_context = None
self.alive = False
def __getattr__(self, name): def __getattr__(self, name):
"""Delegate member lookup to parent context""" """Delegate member lookup to parent context"""
@ -213,7 +217,7 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor.""" @defer.inlineCallbacks is resumed by a callback from the reactor."""
__slots__ = ["current_context", "new_context"] __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel): def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context self.new_context = new_context
@ -224,11 +228,26 @@ class PreserveLoggingContext(object):
self.new_context self.new_context
) )
if self.current_context:
self.has_parent = self.current_context.parent_context is not None
if not self.current_context.alive:
logger.warn(
"Entering dead context: %s",
self.current_context,
)
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
"""Restores the current logging context""" """Restores the current logging context"""
LoggingContext.set_current_context(self.current_context) context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
logger.warn(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
if self.current_context is not LoggingContext.sentinel: if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None: if not self.current_context.alive:
logger.warn( logger.warn(
"Restoring dead context: %s", "Restoring dead context: %s",
self.current_context, self.current_context,
@ -289,3 +308,74 @@ def preserve_context_over_deferred(deferred):
d = _PreservingContextDeferred(current_context) d = _PreservingContextDeferred(current_context)
deferred.chainDeferred(d) deferred.chainDeferred(d)
return d return d
def preserve_fn(f):
"""Ensures that function is called with correct context and that context is
restored after return. Useful for wrapping functions that return a deferred
which you don't yield on.
"""
current = LoggingContext.current_context()
def g(*args, **kwargs):
with PreserveLoggingContext(current):
return f(*args, **kwargs)
return g
# modules to ignore in `logcontext_tracer`
_to_ignore = [
"synapse.util.logcontext",
"synapse.http.server",
"synapse.storage._base",
"synapse.util.async",
]
def logcontext_tracer(frame, event, arg):
"""A tracer that logs whenever a logcontext "unexpectedly" changes within
a function. Probably inaccurate.
Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
"""
if event == 'call':
name = frame.f_globals["__name__"]
if name.startswith("synapse"):
if name == "synapse.util.logcontext":
if frame.f_code.co_name in ["__enter__", "__exit__"]:
tracer = frame.f_back.f_trace
if tracer:
tracer.just_changed = True
tracer = frame.f_trace
if tracer:
return tracer
if not any(name.startswith(ig) for ig in _to_ignore):
return LineTracer()
class LineTracer(object):
__slots__ = ["context", "just_changed"]
def __init__(self):
self.context = LoggingContext.current_context()
self.just_changed = False
def __call__(self, frame, event, arg):
if event in 'line':
if self.just_changed:
self.context = LoggingContext.current_context()
self.just_changed = False
else:
c = LoggingContext.current_context()
if c != self.context:
logger.info(
"Context changed! %s -> %s, %s, %s",
self.context, c,
frame.f_code.co_filename, frame.f_lineno
)
self.context = c
return self

View File

@ -168,3 +168,38 @@ def trace_function(f):
wrapped.__name__ = func_name wrapped.__name__ = func_name
return wrapped return wrapped
def get_previous_frames():
s = inspect.currentframe().f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_return.append("{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
))
s = s.f_back
return ", ". join(to_return)
def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
return "{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
)
s = s.f_back
return None

View File

@ -68,16 +68,18 @@ class Measure(object):
block_timer.inc_by(duration, self.name) block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context() context = LoggingContext.current_context()
if not context:
return
if context != self.start_context: if context != self.start_context:
logger.warn( logger.warn(
"Context have unexpectedly changed %r, %r", "Context have unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context context, self.start_context, self.name
) )
return return
if not context:
logger.warn("Expected context. (%r)", self.name)
return
ru_utime, ru_stime = context.get_resource_usage() ru_utime, ru_stime = context.get_resource_usage()
block_ru_utime.inc_by(ru_utime, self.name) block_ru_utime.inc_by(ru_utime, self.name)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import preserve_fn
import collections import collections
import contextlib import contextlib
@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req", "Ratelimit [%s]: sleeping req",
id(request_id), id(request_id),
) )
ret_defer = sleep(self.sleep_msec / 1000.0) ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id) self.sleeping_requests.add(request_id)