mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-14 18:25:33 -04:00
Rewrite the KeyRing (#10035)
This commit is contained in:
parent
3cf6b34b4e
commit
fc3d2dc269
8 changed files with 403 additions and 502 deletions
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
from typing import Dict, List
|
||||
from unittest.mock import Mock
|
||||
|
||||
import attr
|
||||
|
@ -21,7 +22,6 @@ import signedjson.sign
|
|||
from nacl.signing import SigningKey
|
||||
from signedjson.key import encode_verify_key_base64, get_verify_key
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred, ensureDeferred
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
# deferred completes.
|
||||
first_lookup_deferred = Deferred()
|
||||
|
||||
async def first_lookup_fetch(keys_to_fetch):
|
||||
self.assertEquals(current_context().request.id, "context_11")
|
||||
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
|
||||
async def first_lookup_fetch(
|
||||
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||
) -> Dict[str, FetchKeyResult]:
|
||||
# self.assertEquals(current_context().request.id, "context_11")
|
||||
self.assertEqual(server_name, "server10")
|
||||
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||
self.assertEqual(minimum_valid_until_ts, 0)
|
||||
|
||||
await make_deferred_yieldable(first_lookup_deferred)
|
||||
return {
|
||||
"server10": {
|
||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
||||
}
|
||||
}
|
||||
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
|
||||
|
||||
mock_fetcher.get_keys.side_effect = first_lookup_fetch
|
||||
|
||||
async def first_lookup():
|
||||
with LoggingContext("context_11", request=FakeRequest("context_11")):
|
||||
res_deferreds = kr.verify_json_objects_for_server(
|
||||
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
|
||||
[("server10", json1, 0), ("server11", {}, 0)]
|
||||
)
|
||||
|
||||
# the unsigned json should be rejected pretty quickly
|
||||
|
@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
d0 = ensureDeferred(first_lookup())
|
||||
|
||||
self.pump()
|
||||
|
||||
mock_fetcher.get_keys.assert_called_once()
|
||||
|
||||
# a second request for a server with outstanding requests
|
||||
# should block rather than start a second call
|
||||
|
||||
async def second_lookup_fetch(keys_to_fetch):
|
||||
self.assertEquals(current_context().request.id, "context_12")
|
||||
return {
|
||||
"server10": {
|
||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
||||
}
|
||||
}
|
||||
async def second_lookup_fetch(
|
||||
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||
) -> Dict[str, FetchKeyResult]:
|
||||
# self.assertEquals(current_context().request.id, "context_12")
|
||||
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
|
||||
|
||||
mock_fetcher.get_keys.reset_mock()
|
||||
mock_fetcher.get_keys.side_effect = second_lookup_fetch
|
||||
|
@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
async def second_lookup():
|
||||
with LoggingContext("context_12", request=FakeRequest("context_12")):
|
||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||
[("server10", json1, 0, "test")]
|
||||
[
|
||||
(
|
||||
"server10",
|
||||
json1,
|
||||
0,
|
||||
)
|
||||
]
|
||||
)
|
||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||
second_lookup_state[0] = 1
|
||||
|
@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
signedjson.sign.sign_json(json1, "server9", key1)
|
||||
|
||||
# should fail immediately on an unsigned object
|
||||
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
|
||||
d = kr.verify_json_for_server("server9", {}, 0)
|
||||
self.get_failure(d, SynapseError)
|
||||
|
||||
# should succeed on a signed object
|
||||
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
|
||||
d = kr.verify_json_for_server("server9", json1, 500)
|
||||
# self.assertFalse(d.called)
|
||||
self.get_success(d)
|
||||
|
||||
|
@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
signedjson.sign.sign_json(json1, "server9", key1)
|
||||
|
||||
# should fail immediately on an unsigned object
|
||||
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
|
||||
d = kr.verify_json_for_server("server9", {}, 0)
|
||||
self.get_failure(d, SynapseError)
|
||||
|
||||
# should fail on a signed object with a non-zero minimum_valid_until_ms,
|
||||
# as it tries to refetch the keys and fails.
|
||||
d = _verify_json_for_server(
|
||||
kr, "server9", json1, 500, "test signed non-zero min"
|
||||
)
|
||||
d = kr.verify_json_for_server("server9", json1, 500)
|
||||
self.get_failure(d, SynapseError)
|
||||
|
||||
# We expect the keyring tried to refetch the key once.
|
||||
mock_fetcher.get_keys.assert_called_once_with(
|
||||
{"server9": {get_key_id(key1): 500}}
|
||||
"server9", [get_key_id(key1)], 500
|
||||
)
|
||||
|
||||
# should succeed on a signed object with a 0 minimum_valid_until_ms
|
||||
d = _verify_json_for_server(
|
||||
kr, "server9", json1, 0, "test signed with zero min"
|
||||
d = kr.verify_json_for_server(
|
||||
"server9",
|
||||
json1,
|
||||
0,
|
||||
)
|
||||
self.get_success(d)
|
||||
|
||||
|
@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
"""Two requests for the same key should be deduped."""
|
||||
key1 = signedjson.key.generate_signing_key(1)
|
||||
|
||||
async def get_keys(keys_to_fetch):
|
||||
async def get_keys(
|
||||
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||
) -> Dict[str, FetchKeyResult]:
|
||||
# there should only be one request object (with the max validity)
|
||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||
self.assertEqual(server_name, "server1")
|
||||
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||
|
||||
return {
|
||||
"server1": {
|
||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
||||
}
|
||||
}
|
||||
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
|
||||
|
||||
mock_fetcher = Mock()
|
||||
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
||||
|
@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
# the first request should succeed; the second should fail because the key
|
||||
# has expired
|
||||
results = kr.verify_json_objects_for_server(
|
||||
[("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
|
||||
[
|
||||
(
|
||||
"server1",
|
||||
json1,
|
||||
500,
|
||||
),
|
||||
("server1", json1, 1500),
|
||||
]
|
||||
)
|
||||
self.assertEqual(len(results), 2)
|
||||
self.get_success(results[0])
|
||||
|
@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
||||
key1 = signedjson.key.generate_signing_key(1)
|
||||
|
||||
async def get_keys1(keys_to_fetch):
|
||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||
return {
|
||||
"server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
|
||||
}
|
||||
async def get_keys1(
|
||||
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||
) -> Dict[str, FetchKeyResult]:
|
||||
self.assertEqual(server_name, "server1")
|
||||
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
|
||||
|
||||
async def get_keys2(keys_to_fetch):
|
||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||
return {
|
||||
"server1": {
|
||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
||||
}
|
||||
}
|
||||
async def get_keys2(
|
||||
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||
) -> Dict[str, FetchKeyResult]:
|
||||
self.assertEqual(server_name, "server1")
|
||||
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
|
||||
|
||||
mock_fetcher1 = Mock()
|
||||
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
||||
|
@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
signedjson.sign.sign_json(json1, "server1", key1)
|
||||
|
||||
results = kr.verify_json_objects_for_server(
|
||||
[("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
|
||||
[
|
||||
(
|
||||
"server1",
|
||||
json1,
|
||||
1200,
|
||||
),
|
||||
(
|
||||
"server1",
|
||||
json1,
|
||||
1500,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.assertEqual(len(results), 2)
|
||||
self.get_success(results[0])
|
||||
|
@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.http_client.get_json.side_effect = get_json
|
||||
|
||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||
k = keys[SERVER_NAME][testverifykey_id]
|
||||
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||
k = keys[testverifykey_id]
|
||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||
self.assertEqual(k.verify_key, testverifykey)
|
||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||
|
@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
# change the server name: the result should be ignored
|
||||
response["server_name"] = "OTHER_SERVER"
|
||||
|
||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||
self.assertEqual(keys, {})
|
||||
|
||||
|
||||
|
@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||
|
||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||
self.assertIn(SERVER_NAME, keys)
|
||||
k = keys[SERVER_NAME][testverifykey_id]
|
||||
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||
self.assertIn(testverifykey_id, keys)
|
||||
k = keys[testverifykey_id]
|
||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||
self.assertEqual(k.verify_key, testverifykey)
|
||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||
|
@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||
|
||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||
self.assertIn(SERVER_NAME, keys)
|
||||
k = keys[SERVER_NAME][testverifykey_id]
|
||||
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||
self.assertIn(testverifykey_id, keys)
|
||||
k = keys[testverifykey_id]
|
||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||
self.assertEqual(k.verify_key, testverifykey)
|
||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||
|
@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def get_key_from_perspectives(response):
|
||||
fetcher = PerspectivesKeyFetcher(self.hs)
|
||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||
return self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||
return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||
|
||||
# start with a valid response so we can check we are testing the right thing
|
||||
response = build_response()
|
||||
keys = get_key_from_perspectives(response)
|
||||
k = keys[SERVER_NAME][testverifykey_id]
|
||||
k = keys[testverifykey_id]
|
||||
self.assertEqual(k.verify_key, testverifykey)
|
||||
|
||||
# remove the perspectives server's signature
|
||||
|
@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
def get_key_id(key):
|
||||
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
|
||||
return "%s:%s" % (key.alg, key.version)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def run_in_context(f, *args, **kwargs):
|
||||
with LoggingContext("testctx"):
|
||||
rv = yield f(*args, **kwargs)
|
||||
return rv
|
||||
|
||||
|
||||
def _verify_json_for_server(kr, *args):
|
||||
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
|
||||
with the patched defer.inlineCallbacks.
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def v():
|
||||
rv1 = yield kr.verify_json_for_server(*args)
|
||||
return rv1
|
||||
|
||||
return run_in_context(v)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue