Store key validity time in the storage layer

This is a first step to checking that the key is valid at the required moment.

The idea here is that, rather than passing VerifyKey objects in and out of the
storage layer, we instead pass FetchKeyResult objects, which simply wrap the
VerifyKey and add a valid_until_ts field.
This commit is contained in:
Richard van der Hoff 2019-04-03 18:10:24 +01:00
parent 84660d91b2
commit b75537beaf
6 changed files with 122 additions and 46 deletions

View file

@ -19,6 +19,7 @@ import logging
import six
import attr
from signedjson.key import decode_verify_key_bytes
from synapse.util import batch_iter
@ -36,6 +37,12 @@ else:
db_binary_type = memoryview
@attr.s(slots=True, frozen=True)
class FetchKeyResult(object):
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
map from (server_name, key_id) -> VerifyKey, or None if the key is
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
"""
keys = {}
@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
"SELECT server_name, key_id, verify_key FROM server_signature_keys "
"WHERE 1=0"
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
server_name, key_id, key_bytes = row
keys[(server_name, key_id)] = decode_verify_key_bytes(
key_id, bytes(key_bytes)
server_name, key_id, key_bytes, ts_valid_until_ms = row
res = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
keys[(server_name, key_id)] = res
def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
@ -89,20 +98,21 @@ class KeyStore(SQLBaseStore):
Args:
from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
key_values = []
value_values = []
invalidations = []
for server_name, key_id, verify_key in verify_keys:
for server_name, key_id, fetch_result in verify_keys:
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
db_binary_type(verify_key.encode()),
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
@ -125,6 +135,7 @@ class KeyStore(SQLBaseStore):
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,