Fix up logcontexts

This commit is contained in:
Erik Johnston 2016-02-04 10:22:44 +00:00
parent 13e6262659
commit 2c1fbea531
31 changed files with 356 additions and 229 deletions

View file

@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from twisted.internet import defer
@ -142,40 +146,43 @@ class Keyring(object):
for server_name, _ in server_and_json
}
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
with PreserveLoggingContext():
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads
]
if wait_on:
yield defer.DeferredList(wait_on)
with PreserveLoggingContext():
yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred)
d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d
def rm(r, server_name):
@ -244,12 +252,13 @@ class Keyring(object):
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults(
[
self.store_keys(
preserve_fn(self.store_keys)(
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults(
[
self.store.store_server_keys_json(
preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
key_id=key_id,
from_server=server_name,
@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
[
self.store.store_server_verify_key(
preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()