Turn _start_key_lookups into an inlineCallbacks function

... which means that logcontexts can be correctly preserved for the stuff it
does.

get_server_verify_keys is now called with the logcontext, so needs to
preserve_fn when it fires off its nested inlineCallbacks function.

Also renames get_server_verify_keys to reflect the fact it's meant to be
private.
This commit is contained in:
Richard van der Hoff 2017-09-20 01:32:42 +01:00
parent abdefb8a01
commit c5b0e9f485

View File

@ -123,7 +123,7 @@ class Keyring(object):
verify_requests.append(verify_request) verify_requests.append(verify_request)
self._start_key_lookups(verify_requests) preserve_fn(self._start_key_lookups)(verify_requests)
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
@ -132,6 +132,7 @@ class Keyring(object):
for rq in verify_requests for rq in verify_requests
] ]
@defer.inlineCallbacks
def _start_key_lookups(self, verify_requests): def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request """Sets off the key fetches for each verify request
@ -151,47 +152,43 @@ class Keyring(object):
for rq in verify_requests for rq in verify_requests
} }
with PreserveLoggingContext(): # We want to wait for any previous lookups to complete before
# proceeding.
yield self.wait_for_previous_lookups(
[rq.server_name for rq in verify_requests],
server_to_deferred,
)
# We want to wait for any previous lookups to complete before # Actually start fetching keys.
# proceeding. self._get_server_verify_keys(verify_requests)
wait_on_deferred = self.wait_for_previous_lookups(
[rq.server_name for rq in verify_requests], # When we've finished fetching all the keys for a given server_name,
server_to_deferred, # resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
#
# map from server name to a set of request ids
server_to_request_ids = {}
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
def remove_deferreds(res, verify_request):
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for verify_request in verify_requests:
verify_request.deferred.addBoth(
remove_deferreds, verify_request,
) )
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
#
# map from server name to a set of request ids
server_to_request_ids = {}
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
def remove_deferreds(res, verify_request):
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for verify_request in verify_requests:
verify_request.deferred.addBoth(
remove_deferreds, verify_request,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred): def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
@ -227,7 +224,7 @@ class Keyring(object):
self.key_downloads[server_name] = deferred self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name) deferred.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests): def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request
For each verify_request, verify_request.deferred is called back with For each verify_request, verify_request.deferred is called back with
@ -312,7 +309,7 @@ class Keyring(object):
if not verify_request.deferred.called: if not verify_request.deferred.called:
verify_request.deferred.errback(err) verify_request.deferred.errback(err)
do_iterations().addErrback(on_err) preserve_fn(do_iterations)().addErrback(on_err)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):