Implement bulk verify_signed_json API

This commit is contained in:
Erik Johnston 2015-06-26 09:52:24 +01:00
parent 6924852592
commit b5f55a1d85
4 changed files with 439 additions and 215 deletions

View File

@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from OpenSSL import crypto from OpenSSL import crypto
from collections import namedtuple
import urllib import urllib
import hashlib import hashlib
import logging import logging
@ -38,6 +38,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -49,141 +52,274 @@ class Keyring(object):
self.key_downloads = {} self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
logger.debug("Verifying for %s", server_name) return self.verify_json_objects_for_server(
key_ids = signature_ids(json_object, server_name) [(server_name, json_object)]
if not key_ids: )[0]
raise SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
try: def verify_json_objects_for_server(self, server_and_json):
verify_signed_json(json_object, server_name, verify_key) """Bulk verfies signatures of json objects, bulk fetching keys as
except: necessary.
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Trys to fetch the key from a trusted perspective server first.
Args: Args:
server_name(str): The name of the server to fetch a key for. server_and_json (list): List of pairs of (server_name, json_object)
keys_ids (list of str): The key_ids to check for.
Returns:
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
""" """
cached = yield self.store.get_server_verify_keys(server_name, key_ids) group_id_to_json = {}
group_id_to_group = {}
group_ids = []
if cached: next_group_id = 0
defer.returnValue(cached[0]) deferreds = {}
return
download = self.key_downloads.get(server_name) for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
if download is None: key_ids = signature_ids(json_object, server_name)
download = self._get_server_verify_key_impl(server_name, key_ids) if not key_ids:
download = ObservableDeferred( deferreds[group_id] = defer.fail(SynapseError(
download, 400,
consumeErrors=True "Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try:
_, _, key_id, verify_key = yield deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
deferreds.update(self.get_server_verify_keys(
group_id_to_group
))
return [
handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
) )
self.key_downloads[server_name] = download for g_id in group_ids
]
@download.addBoth def get_server_verify_keys(self, group_id_to_group):
def callback(ret): """Takes a dict of KeyGroups and tries to find at least one key for
del self.key_downloads[server_name] each group.
return ret """
r = yield download.observe() # These are functions that produce keys given a list of key ids
defer.returnValue(r) key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
group_deferreds = {
group_id: defer.Deferred()
for group_id in group_id_to_group
}
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
missing_keys = {
group.server_name: key_id
for group in group_id_to_group.values()
for key_id in group.key_ids
}
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_deferreds.pop(group.group_id).callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_deferreds.pop(group.group_id).errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_deferreds.values():
deferred.errback(err)
group_deferreds.clear()
do_iterations().addErrback(on_err)
return group_deferreds
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids): def get_keys_from_store(self, server_name_and_key_ids):
keys = None res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(dict(zip(
[server_name for server_name, _ in server_name_and_key_ids],
res
)))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys server_name_and_key_ids, perspective_name, perspective_keys
) )
defer.returnValue(result) defer.returnValue(result)
except Exception as e: except Exception as e:
logging.info( logger.exception(
"Unable to getting key %r for %r from %r: %s %s", "Unable to get key from %r: %s %s",
key_ids, server_name, perspective_name, perspective_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
defer.returnValue({})
perspective_results = yield defer.gatherResults([ results = yield defer.gatherResults(
get_key(p_name, p_keys) [
for p_name, p_keys in self.perspective_servers.items() get_key(p_name, p_keys)
]) for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for results in perspective_results: union_of_keys = {}
if results is not None: for result in results:
keys = results for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
limiter = yield get_retry_limiter( defer.returnValue(union_of_keys)
server_name,
self.clock,
self.store,
)
with limiter: @defer.inlineCallbacks
if not keys: def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
keys = None
try: try:
keys = yield self.get_server_verify_key_v2_direct( keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids server_name, key_ids
) )
except Exception as e: except Exception as e:
logging.info( logger.info(
"Unable to getting key %r for %r directly: %s %s", "Unable to getting key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
if not keys: if not keys:
keys = yield self.get_server_verify_key_v1_direct( keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids server_name, key_ids
) )
for key_id in key_ids: keys = {server_name: keys}
if key_id in keys:
defer.returnValue(keys[key_id]) defer.returnValue(keys)
return
raise ValueError("No verification key found for given key ids") results = yield defer.gatherResults(
[
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids, def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name, perspective_name,
perspective_keys): perspective_keys):
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
@ -204,6 +340,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0 u"minimum_valid_until_ts": 0
} for key_id in key_ids } for key_id in key_ids
} }
for server_name, key_ids in server_names_and_key_ids
} }
}, },
) )
@ -243,23 +380,29 @@ class Keyring(object):
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
response_keys = yield self.process_v2_response( processed_response = yield self.process_v2_response(
server_name, perspective_name, response perspective_name, response
) )
keys.update(response_keys) for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
yield self.store_keys( yield defer.gatherResults(
server_name=server_name, [
from_server=perspective_name, self.store_keys(
verify_keys=keys, server_name=server_name,
) from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} keys = {}
for requested_key_id in key_ids: for requested_key_id in key_ids:
@ -295,25 +438,30 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints") raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name, from_server=server_name,
requested_id=requested_key_id, requested_ids=[requested_key_id],
response_json=response, response_json=response,
) )
keys.update(response_keys) keys.update(response_keys)
yield self.store_keys( yield defer.gatherResults(
server_name=server_name, [
from_server=server_name, self.store_keys(
verify_keys=keys, server_name=key_server_name,
) from_server=server_name,
verify_keys=verify_keys,
)
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json, def process_v2_response(self, from_server, response_json,
requested_id=None): requested_ids=[]):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@ -335,6 +483,8 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise ValueError(
@ -357,28 +507,31 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"] ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set() updated_key_ids = set(requested_ids)
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys) updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys) updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys) response_keys.update(verify_keys)
response_keys.update(old_verify_keys) response_keys.update(old_verify_keys)
for key_id in updated_key_ids: yield defer.gatherResults(
yield self.store.store_server_keys_json( [
server_name=server_name, self.store.store_server_keys_json(
key_id=key_id, server_name=server_name,
from_server=server_name, key_id=key_id,
ts_now_ms=time_now_ms, from_server=server_name,
ts_expires_ms=ts_valid_until_ms, ts_now_ms=time_now_ms,
key_json_bytes=signed_key_json_bytes, ts_expires_ms=ts_valid_until_ms,
) key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(response_keys) results[server_name] = response_keys
raise ValueError("No verification key found for given key ids") defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids): def get_server_verify_key_v1_direct(self, server_name, key_ids):
@ -462,8 +615,13 @@ class Keyring(object):
Returns: Returns:
A deferred that completes when the keys are stored. A deferred that completes when the keys are stored.
""" """
for key_id, key in verify_keys.items(): # TODO(markjh): Store whether the keys have expired.
# TODO(markjh): Store whether the keys have expired. yield defer.gatherResults(
yield self.store.store_server_verify_key( [
server_name, server_name, key.time_added, key self.store.store_server_verify_key(
) server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)

View File

@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False): def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each """Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of the database and if not then request if from the originating server of
@ -50,84 +51,108 @@ class FederationBase(object):
Returns: Returns:
Deferred : A list of PDUs that have valid signatures and hashes. Deferred : A list of PDUs that have valid signatures and hashes.
""" """
deferreds = self._check_sigs_and_hashes(pdus)
signed_pdus = [] def callback(pdu):
return pdu
@defer.inlineCallbacks def errback(failure, pdu):
def do(pdu): failure.trap(SynapseError)
try: return None
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
def try_local_db(res, pdu):
if not res:
# Check local db. # Check local db.
new_pdu = yield self.store.get_event( return self.store.get_event(
pdu.event_id, pdu.event_id,
allow_rejected=True, allow_rejected=True,
allow_none=True, allow_none=True,
) )
if new_pdu: return res
signed_pdus.append(new_pdu)
return
# Check pdu.origin def try_remote(res, pdu):
if pdu.origin != origin: if not res and pdu.origin != origin:
try: return self.get_pdu(
new_pdu = yield self.get_pdu( destinations=[pdu.origin],
destinations=[pdu.origin], event_id=pdu.event_id,
event_id=pdu.event_id, outlier=outlier,
outlier=outlier, timeout=10000,
timeout=10000, ).addErrback(lambda e: None)
) return res
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
def warn(res, pdu):
if not res:
logger.warn( logger.warn(
"Failed to find copy of %s with valid signature", "Failed to find copy of %s with valid signature",
pdu.event_id, pdu.event_id,
) )
return res
yield defer.gatherResults( for pdu, deferred in zip(pdus, deferreds):
[do(pdu) for pdu in pdus], deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
valid_pdus = yield defer.gatherResults(
deferreds,
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus) if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu): def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
signatures. signatures.
Returns: Returns:
FrozenEvent: Either the given event or it redacted if it failed the FrozenEvent: Either the given event or it redacted if it failed the
content hash check. content hash check.
""" """
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try: redacted_pdus = [
yield self.keyring.verify_json_for_server( prune_event(pdu)
pdu.origin, redacted_pdu_json for pdu in pdus
) ]
except SynapseError:
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
logger.warn( logger.warn(
"Signature check failed for %s", "Signature check failed for %s",
pdu.event_id, pdu.event_id,
) )
raise return failure
if not check_event_content_hash(pdu): for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
logger.warn( deferred.addCallbacks(
"Event content has been tampered, redacting.", callback, errback,
pdu.event_id, callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
) )
defer.returnValue(redacted_event)
defer.returnValue(pdu) return deferreds

View File

@ -30,6 +30,7 @@ import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import copy
import itertools import itertools
import logging import logging
import random import random
@ -167,7 +168,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults( pdus[:] = yield defer.gatherResults(
[self._check_sigs_and_hash(pdu) for pdu in pdus], self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -230,7 +231,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hashes([pdu])[0]
break break
@ -327,6 +328,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id): def make_join(self, destinations, room_id, user_id):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
ret = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_join(
destination, room_id, user_id destination, room_id, user_id
@ -353,6 +357,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(
@ -374,17 +381,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
signed_state, signed_auth = yield defer.gatherResults( pdus = {
[ p.event_id: p
self._check_sigs_and_hash_and_fetch( for p in itertools.chain(state, auth_chain)
destination, state, outlier=True }
),
self._check_sigs_and_hash_and_fetch( valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True destination, pdus.values(),
) outlier=True,
], )
consumeErrors=True
).addErrback(unwrapFirstError) valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
@ -396,7 +425,7 @@ class FederationClient(FederationBase):
except CodeMessageException: except CodeMessageException:
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.exception(
"Failed to send_join via %s: %s", "Failed to send_join via %s: %s",
destination, e.message destination, e.message
) )

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore from _base import SQLBaseStore, cached
from twisted.internet import defer from twisted.internet import defer
@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate", desc="store_server_certificate",
) )
@cached()
@defer.inlineCallbacks
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
},
retcols=["key_id", "verify_key"],
desc="get_all_server_verify_keys",
)
defer.returnValue({
row["key_id"]: decode_verify_key_bytes(
row["key_id"], str(row["verify_key"])
)
for row in rows
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids): def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given """Retrieve the NACL verification key for a given server for the given
@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
Returns: Returns:
(list of VerifyKey): The verification keys. (list of VerifyKey): The verification keys.
""" """
sql = ( keys = yield self.get_all_server_verify_keys(server_name)
"SELECT key_id, verify_key FROM server_signature_keys" defer.returnValue({
" WHERE server_name = ?" k: keys[k]
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" for k in key_ids
) if k in keys and keys[k]
})
rows = yield self._execute_and_decode(
"get_server_verify_keys", sql, server_name, *key_ids
)
keys = []
for row in rows:
key_id = row["key_id"]
key_bytes = row["verify_key"]
key = decode_verify_key_bytes(key_id, str(key_bytes))
keys.append(key)
defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms, def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key): verify_key):
"""Stores a NACL verification key for the given server. """Stores a NACL verification key for the given server.
@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key. verification_key (VerifyKey): The NACL verify key.
""" """
return self._simple_upsert( yield self._simple_upsert(
table="server_signature_keys", table="server_signature_keys",
keyvalues={ keyvalues={
"server_name": server_name, "server_name": server_name,
@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key", desc="store_server_verify_key",
) )
self.get_all_server_verify_keys.invalidate(server_name)
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server """Stores the JSON bytes for a set of keys from a server
@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms, "ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes), "key_json": buffer(key_json_bytes),
}, },
desc="store_server_keys_json",
) )
def get_server_keys_json(self, server_keys): def get_server_keys_json(self, server_keys):