Use servlets for /key/ endpoints. (#14229)

To fix the response for unknown endpoints under that prefix.

See MSC3743.
This commit is contained in:
Patrick Cloke 2022-10-20 11:32:47 -04:00 committed by GitHub
parent da2c93d4b6
commit 755bfeee3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 86 additions and 83 deletions

View file

@ -13,15 +13,20 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Set
import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from twisted.web.server import Request
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@ -32,7 +37,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class RemoteKey(DirectServeJsonResource):
class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource):
}
"""
isLeaf = True
def __init__(self, hs: "HomeServer"):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource):
)
self.config = hs.config
async def _async_render_GET(self, request: SynapseRequest) -> None:
assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
def register(self, http_server: HttpServer) -> None:
http_server.register_paths(
"GET",
(
re.compile(
"^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
),
),
self.on_GET,
self.__class__.__name__,
)
http_server.register_paths(
"POST",
(re.compile("^/_matrix/key/v2/query$"),),
self.on_POST,
self.__class__.__name__,
)
async def on_GET(
self, request: Request, server: str, key_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
if server and key_id:
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
query = {server: {key_id: arguments}}
else:
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
query = {server: {}}
await self.query_keys(request, query, query_remote_on_cache_miss=True)
return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request: SynapseRequest) -> None:
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True)
return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def query_keys(
self,
request: SynapseRequest,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
self, query: JsonDict, query_remote_on_cache_miss: bool = False
) -> JsonDict:
logger.info("Handling query for keys %r", query)
store_queries = []
@ -232,7 +245,7 @@ class RemoteKey(DirectServeJsonResource):
for server_name, keys in cache_misses.items()
),
)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
return await self.query_keys(query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json_raw in json_results:
@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json)
response = {"server_keys": signed_keys}
respond_with_json(request, 200, response, canonical_json=True)
return {"server_keys": signed_keys}