Implement locks using create_observer for fetching media and server keys

This commit is contained in:
Erik Johnston 2015-04-27 14:20:26 +01:00
parent 1c82fbd2eb
commit e701aec2d1
2 changed files with 87 additions and 65 deletions

View File

@ -24,6 +24,8 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util.async import create_observer
from OpenSSL import crypto from OpenSSL import crypto
import logging import logging
@ -38,6 +40,8 @@ class Keyring(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self.key_downloads = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
@ -97,6 +101,8 @@ class Keyring(object):
defer.returnValue(cached[0]) defer.returnValue(cached[0])
return return
@defer.inlineCallbacks
def fetch_keys():
# Try to fetch the key from the remote server. # Try to fetch the key from the remote server.
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
@ -170,3 +176,17 @@ class Keyring(object):
return return
raise ValueError("No verification key found for given key ids") raise ValueError("No verification key found for given key ids")
download = self.key_downloads.get(server_name)
if download is None:
download = fetch_keys()
self.key_downloads[server_name] = download
@download.addBoth
def callback(ret):
del self.key_downloads[server_name]
return ret
r = yield create_observer(download)
defer.returnValue(r)

View File

@ -25,6 +25,8 @@ from twisted.internet import defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.util.async import create_observer
import os import os
import logging import logging
@ -87,7 +89,7 @@ class BaseMediaResource(Resource):
def callback(media_info): def callback(media_info):
del self.downloads[key] del self.downloads[key]
return media_info return media_info
return download return create_observer(download)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id): def _get_remote_media_impl(self, server_name, media_id):