Merge pull request #194 from matrix-org/erikj/bulk_verify_sigs

Implement bulk verify_signed_json API
This commit is contained in:
Erik Johnston 2015-07-10 13:46:53 +01:00
commit 0b3389bcd2
4 changed files with 492 additions and 215 deletions

View File

@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from OpenSSL import crypto
from collections import namedtuple
import urllib
import hashlib
import logging
@ -38,6 +40,9 @@ import logging
logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@ -49,18 +54,55 @@ class Keyring(object):
self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
return self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verfies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
raise SynapseError(
deferreds[group_id] = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
))
else:
deferreds[group_id] = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
_, _, key_id, verify_key = yield deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
@ -72,7 +114,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.warn(
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
@ -82,6 +124,8 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
try:
verify_signed_json(json_object, server_name, verify_key)
except:
@ -93,79 +137,208 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Trys to fetch the key from a trusted perspective server first.
Args:
server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
server_to_deferred = {
server_name: defer.Deferred()
for server_name, _ in server_and_json
}
if cached:
defer.returnValue(cached[0])
return
download = self.key_downloads.get(server_name)
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
download = ObservableDeferred(
download,
consumeErrors=True
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
self.key_downloads[server_name] = download
@download.addBoth
def callback(ret):
del self.key_downloads[server_name]
return ret
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
r = yield download.observe()
defer.returnValue(r)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
server_to_deferred.pop(server_name).callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
]
@defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids):
keys = None
def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server
"""
while True:
wait_on = [
self.key_downloads[server_name]
for server_name in server_names
if server_name in self.key_downloads
]
if wait_on:
yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred:
self.key_downloads[server_name] = ObservableDeferred(deferred)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
missing_keys = {
group.server_name: key_id
for group in group_id_to_group.values()
for key_id in group.key_ids
}
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(dict(zip(
[server_name for server_name, _ in server_name_and_key_ids],
res
)))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys
server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
logging.info(
"Unable to getting key %r for %r from %r: %s %s",
key_ids, server_name, perspective_name,
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
type(e).__name__, str(e.message),
)
defer.returnValue({})
perspective_results = yield defer.gatherResults([
results = yield defer.gatherResults(
[
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
])
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for results in perspective_results:
if results is not None:
keys = results
union_of_keys = {}
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
if not keys:
keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logging.info(
logger.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
@ -176,14 +349,30 @@ class Keyring(object):
server_name, key_ids
)
for key_id in key_ids:
if key_id in keys:
defer.returnValue(keys[key_id])
return
raise ValueError("No verification key found for given key ids")
keys = {server_name: keys}
defer.returnValue(keys)
results = yield defer.gatherResults(
[
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
@ -204,6 +393,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
for server_name, key_ids in server_names_and_key_ids
}
},
)
@ -243,23 +433,29 @@ class Keyring(object):
" server %r" % (perspective_name,)
)
response_keys = yield self.process_v2_response(
server_name, perspective_name, response
processed_response = yield self.process_v2_response(
perspective_name, response
)
keys.update(response_keys)
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
yield self.store_keys(
yield defer.gatherResults(
[
self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=keys,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {}
for requested_key_id in key_ids:
@ -295,25 +491,30 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name,
requested_id=requested_key_id,
requested_ids=[requested_key_id],
response_json=response,
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name,
yield defer.gatherResults(
[
self.store_keys(
server_name=key_server_name,
from_server=server_name,
verify_keys=keys,
verify_keys=verify_keys,
)
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json,
requested_id=None):
def process_v2_response(self, from_server, response_json,
requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@ -335,6 +536,8 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
@ -357,17 +560,16 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set()
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
for key_id in updated_key_ids:
yield self.store.store_server_keys_json(
yield defer.gatherResults(
[
self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
@ -375,10 +577,14 @@ class Keyring(object):
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(response_keys)
results[server_name] = response_keys
raise ValueError("No verification key found for given key ids")
defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
@ -462,8 +668,13 @@ class Keyring(object):
Returns:
A deferred that completes when the keys are stored.
"""
for key_id, key in verify_keys.items():
# TODO(markjh): Store whether the keys have expired.
yield self.store.store_server_verify_key(
yield defer.gatherResults(
[
self.store.store_server_verify_key(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)

View File

@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@ -50,84 +51,108 @@ class FederationBase(object):
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(pdus)
signed_pdus = []
def callback(pdu):
return pdu
@defer.inlineCallbacks
def do(pdu):
try:
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
def errback(failure, pdu):
failure.trap(SynapseError)
return None
def try_local_db(res, pdu):
if not res:
# Check local db.
new_pdu = yield self.store.get_event(
return self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
return res
# Check pdu.origin
if pdu.origin != origin:
try:
new_pdu = yield self.get_pdu(
def try_remote(res, pdu):
if not res and pdu.origin != origin:
return self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
timeout=10000,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
).addErrback(lambda e: None)
return res
def warn(res, pdu):
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
return res
yield defer.gatherResults(
[do(pdu) for pdu in pdus],
for pdu, deferred in zip(pdus, deferreds):
deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
valid_pdus = yield defer.gatherResults(
deferreds,
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus)
if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
redacted_pdus = [
prune_event(pdu)
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
except SynapseError:
return redacted
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
logger.warn(
"Signature check failed for %s",
pdu.event_id,
)
raise
return failure
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting.",
pdu.event_id,
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)
return deferreds

View File

@ -30,6 +30,7 @@ import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import copy
import itertools
import logging
import random
@ -167,7 +168,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
[self._check_sigs_and_hash(pdu) for pdu in pdus],
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
@ -230,7 +231,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@ -327,6 +328,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
if destination == self.server_name:
continue
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
@ -353,6 +357,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
if destination == self.server_name:
continue
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
@ -374,17 +381,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
signed_state, signed_auth = yield defer.gatherResults(
[
self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
),
self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, pdus.values(),
outlier=True,
)
],
consumeErrors=True
).addErrback(unwrapFirstError)
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth)
@ -396,7 +425,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
logger.warn(
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore
from _base import SQLBaseStore, cached
from twisted.internet import defer
@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
@cached()
@defer.inlineCallbacks
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
},
retcols=["key_id", "verify_key"],
desc="get_all_server_verify_keys",
)
defer.returnValue({
row["key_id"]: decode_verify_key_bytes(
row["key_id"], str(row["verify_key"])
)
for row in rows
})
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given
@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
sql = (
"SELECT key_id, verify_key FROM server_signature_keys"
" WHERE server_name = ?"
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
)
rows = yield self._execute_and_decode(
"get_server_verify_keys", sql, server_name, *key_ids
)
keys = []
for row in rows:
key_id = row["key_id"]
key_bytes = row["verify_key"]
key = decode_verify_key_bytes(key_id, str(key_bytes))
keys.append(key)
defer.returnValue(keys)
keys = yield self.get_all_server_verify_keys(server_name)
defer.returnValue({
k: keys[k]
for k in key_ids
if k in keys and keys[k]
})
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key.
"""
return self._simple_upsert(
yield self._simple_upsert(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
self.get_all_server_verify_keys.invalidate(server_name)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes),
},
desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):