Store key validity time in the storage layer

This is a first step to checking that the key is valid at the required moment.

The idea here is that, rather than passing VerifyKey objects in and out of the
storage layer, we instead pass FetchKeyResult objects, which simply wrap the
VerifyKey and add a valid_until_ts field.
This commit is contained in:
Richard van der Hoff 2019-04-03 18:10:24 +01:00
parent 84660d91b2
commit b75537beaf
6 changed files with 122 additions and 46 deletions

1
changelog.d/5237.misc Normal file
View File

@ -0,0 +1 @@
Store key validity time in the storage layer.

View File

@ -20,7 +20,6 @@ from collections import namedtuple
from six import raise_from from six import raise_from
from six.moves import urllib from six.moves import urllib
import nacl.signing
from signedjson.key import ( from signedjson.key import (
decode_verify_key_bytes, decode_verify_key_bytes,
encode_verify_key_base64, encode_verify_key_base64,
@ -43,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
) )
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import ( from synapse.util.logcontext import (
LoggingContext, LoggingContext,
@ -307,11 +307,15 @@ class Keyring(object):
# complete this VerifyKeyRequest. # complete this VerifyKeyRequest.
result_keys = results.get(server_name, {}) result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids: for key_id in verify_request.key_ids:
key = result_keys.get(key_id) fetch_key_result = result_keys.get(key_id)
if key: if fetch_key_result:
with PreserveLoggingContext(): with PreserveLoggingContext():
verify_request.deferred.callback( verify_request.deferred.callback(
(server_name, key_id, key) (
server_name,
key_id,
fetch_key_result.verify_key,
)
) )
break break
else: else:
@ -348,12 +352,12 @@ class Keyring(object):
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
""" """
Args: Args:
server_name_and_key_ids (iterable(Tuple[str, iterable[str]]): server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for list of (server_name, iterable[key_id]) tuples to fetch keys for
Returns: Returns:
Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
server_name -> key_id -> VerifyKey map from server_name -> key_id -> FetchKeyResult
""" """
keys_to_fetch = ( keys_to_fetch = (
(server_name, key_id) (server_name, key_id)
@ -430,6 +434,18 @@ class Keyring(object):
def get_server_verify_key_v2_indirect( def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys self, server_names_and_key_ids, perspective_name, perspective_keys
): ):
"""
Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
from server_name -> key_id -> FetchKeyResult
"""
# TODO(mark): Set the minimum_valid_until_ts to that needed by # TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating # the events being validated or the current time if validating
# an incoming request. # an incoming request.
@ -506,7 +522,7 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, nacl.signing.VerifyKey] keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids: for requested_key_id in key_ids:
if requested_key_id in keys: if requested_key_id in keys:
@ -583,9 +599,9 @@ class Keyring(object):
actually in the response actually in the response
Returns: Returns:
Deferred[dict[str, nacl.signing.VerifyKey]]: Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
map from key_id to key object
""" """
ts_valid_until_ms = response_json[u"valid_until_ts"]
# start by extracting the keys from the response, since they may be required # start by extracting the keys from the response, since they may be required
# to validate the signature on the response. # to validate the signature on the response.
@ -595,7 +611,9 @@ class Keyring(object):
key_base64 = key_data["key"] key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64) key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes) verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = verify_key verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
# TODO: improve this signature checking # TODO: improve this signature checking
server_name = response_json["server_name"] server_name = response_json["server_name"]
@ -606,7 +624,7 @@ class Keyring(object):
) )
verify_signed_json( verify_signed_json(
response_json, server_name, verify_keys[key_id] response_json, server_name, verify_keys[key_id].verify_key
) )
for key_id, key_data in response_json["old_verify_keys"].items(): for key_id, key_data in response_json["old_verify_keys"].items():
@ -614,7 +632,9 @@ class Keyring(object):
key_base64 = key_data["key"] key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64) key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes) verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = verify_key verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to # re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server # give it out as a notary server
@ -623,7 +643,6 @@ class Keyring(object):
) )
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
# for reasons I don't quite understand, we store this json for the key ids we # for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got. # requested, as well as those we got.

View File

@ -19,6 +19,7 @@ import logging
import six import six
import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from synapse.util import batch_iter from synapse.util import batch_iter
@ -36,6 +37,12 @@ else:
db_binary_type = memoryview db_binary_type = memoryview
@attr.s(slots=True, frozen=True)
class FetchKeyResult(object):
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for
class KeyStore(SQLBaseStore): class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys """Persistence for signature verification keys
""" """
@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
iterable of (server_name, key-id) tuples to fetch keys for iterable of (server_name, key-id) tuples to fetch keys for
Returns: Returns:
Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> VerifyKey, or None if the key is map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown unknown
""" """
keys = {} keys = {}
@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
# batch_iter always returns tuples so it's safe to do len(batch) # batch_iter always returns tuples so it's safe to do len(batch)
sql = ( sql = (
"SELECT server_name, key_id, verify_key FROM server_signature_keys " "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"WHERE 1=0" "FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch) ) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn: for row in txn:
server_name, key_id, key_bytes = row server_name, key_id, key_bytes, ts_valid_until_ms = row
keys[(server_name, key_id)] = decode_verify_key_bytes( res = FetchKeyResult(
key_id, bytes(key_bytes) verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
) )
keys[(server_name, key_id)] = res
def _txn(txn): def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50): for batch in batch_iter(server_name_and_key_ids, 50):
@ -89,20 +98,21 @@ class KeyStore(SQLBaseStore):
Args: Args:
from_server (str): Where the verification keys were looked up from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]): verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
keys to be stored. Each entry is a triplet of keys to be stored. Each entry is a triplet of
(server_name, key_id, key). (server_name, key_id, key).
""" """
key_values = [] key_values = []
value_values = [] value_values = []
invalidations = [] invalidations = []
for server_name, key_id, verify_key in verify_keys: for server_name, key_id, fetch_result in verify_keys:
key_values.append((server_name, key_id)) key_values.append((server_name, key_id))
value_values.append( value_values.append(
( (
from_server, from_server,
ts_added_ms, ts_added_ms,
db_binary_type(verify_key.encode()), fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
) )
) )
# invalidate takes a tuple corresponding to the params of # invalidate takes a tuple corresponding to the params of
@ -125,6 +135,7 @@ class KeyStore(SQLBaseStore):
value_names=( value_names=(
"from_server", "from_server",
"ts_added_ms", "ts_added_ms",
"ts_valid_until_ms",
"verify_key", "verify_key",
), ),
value_values=value_values, value_values=value_values,

View File

@ -0,0 +1,23 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* When we can use this key until, before we have to refresh it. */
ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
UPDATE server_signature_keys SET ts_valid_until_ms = (
SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
skj.server_name = server_signature_keys.server_name AND
skj.key_id = server_signature_keys.key_id
);

View File

@ -25,6 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.crypto.keyring import KeyLookupError from synapse.crypto.keyring import KeyLookupError
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
( (
"server9", "server9",
key1_id, key1_id,
signedjson.key.get_verify_key(key1), FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
), ),
], ],
) )
@ -251,9 +252,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.alg, "ed25519") self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.version, "ver1") self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated # check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None) lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@ -321,9 +323,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys) self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.alg, "ed25519") self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.version, "ver1") self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated # check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None) lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@ -346,7 +349,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def run_in_context(f, *args, **kwargs): def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx"): with LoggingContext("testctx") as ctx:
# we set the "request" prop to make it easier to follow what's going on in the
# logs.
ctx.request = "testctx"
rv = yield f(*args, **kwargs) rv = yield f(*args, **kwargs)
defer.returnValue(rv) defer.returnValue(rv)

View File

@ -17,6 +17,8 @@ import signedjson.key
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from synapse.storage.keys import FetchKeyResult
import tests.unittest import tests.unittest
KEY_1 = signedjson.key.decode_verify_key_base64( KEY_1 = signedjson.key.decode_verify_key_base64(
@ -37,8 +39,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
"from_server", "from_server",
10, 10,
[ [
("server1", key_id_1, KEY_1), ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
("server1", key_id_2, KEY_2), ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
], ],
) )
self.get_success(d) self.get_success(d)
@ -50,13 +52,15 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(len(res.keys()), 3) self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)] res1 = res[("server1", key_id_1)]
self.assertEqual(res1, KEY_1) self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.version, "key1") self.assertEqual(res1.verify_key.version, "key1")
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("server1", key_id_2)] res2 = res[("server1", key_id_2)]
self.assertEqual(res2, KEY_2) self.assertEqual(res2.verify_key, KEY_2)
# version comes from the ID it was stored with # version comes from the ID it was stored with
self.assertEqual(res2.version, "KEY_ID_2") self.assertEqual(res2.verify_key.version, "KEY_ID_2")
self.assertEqual(res2.valid_until_ts, 200)
# non-existent result gives None # non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")]) self.assertIsNone(res[("server1", "ed25519:key3")])
@ -73,8 +77,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
"from_server", "from_server",
0, 0,
[ [
("srv1", key_id_1, KEY_1), ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
("srv1", key_id_2, KEY_2), ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
], ],
) )
self.get_success(d) self.get_success(d)
@ -82,26 +86,38 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d) res = self.get_success(d)
self.assertEqual(len(res.keys()), 2) self.assertEqual(len(res.keys()), 2)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], KEY_2) res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit # we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)]) res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred): if isinstance(res, Deferred):
res = self.successResultOf(res) res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1) self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)], KEY_1) self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
new_key_2 = signedjson.key.get_verify_key( new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2") signedjson.key.generate_signing_key("key2")
) )
d = store.store_server_verify_keys( d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, new_key_2)] "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
) )
self.get_success(d) self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d) res = self.get_success(d)
self.assertEqual(len(res.keys()), 2) self.assertEqual(len(res.keys()), 2)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], new_key_2) res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, new_key_2)
self.assertEqual(res2.valid_until_ts, 300)