Finish type hints for federation client HTTP code. (#15465)

This commit is contained in:
Patrick Cloke 2023-04-24 13:12:06 -04:00 committed by GitHub
parent 19141b9432
commit ea5c3ede4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 82 additions and 42 deletions

1
changelog.d/15465.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -33,12 +33,6 @@ exclude = (?x)
|synapse/storage/schema/ |synapse/storage/schema/
)$ )$
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False
[mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False
[mypy-synapse.metrics._reactor_metrics] [mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS. # This module imports select.epoll. That exists on Linux, but doesn't on macOS.

View File

@ -280,15 +280,11 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%r", transaction_data) logger.debug("backfill transaction_data=%r", transaction_data)
if not isinstance(transaction_data, dict): if not isinstance(transaction_data, dict):
# TODO we probably want an exception type specific to federation raise InvalidResponseError("Backfill transaction_data is not a dict.")
# client validation.
raise TypeError("Backfill transaction_data is not a dict.")
transaction_data_pdus = transaction_data.get("pdus") transaction_data_pdus = transaction_data.get("pdus")
if not isinstance(transaction_data_pdus, list): if not isinstance(transaction_data_pdus, list):
# TODO we probably want an exception type specific to federation raise InvalidResponseError("transaction_data.pdus is not a list.")
# client validation.
raise TypeError("transaction_data.pdus is not a list.")
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)

View File

@ -16,6 +16,7 @@
import logging import logging
import urllib import urllib
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
Collection, Collection,
@ -42,18 +43,21 @@ from synapse.api.urls import (
) )
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import ExceptionBundle from synapse.util import ExceptionBundle
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TransportLayerClient: class TransportLayerClient:
"""Sends federation HTTP requests to other servers""" """Sends federation HTTP requests to other servers"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname self.server_name = hs.hostname
self.client = hs.get_federation_http_client() self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
@ -133,7 +137,7 @@ class TransportLayerClient:
async def backfill( async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]: ) -> Optional[Union[JsonDict, list]]:
"""Requests `limit` previous PDUs in a given context before list of """Requests `limit` previous PDUs in a given context before list of
PDUs. PDUs.
@ -388,6 +392,7 @@ class TransportLayerClient:
# server was just having a momentary blip, the room will be out of # server was just having a momentary blip, the room will be out of
# sync. # sync.
ignore_backoff=True, ignore_backoff=True,
parser=LegacyJsonSendParser(),
) )
async def send_leave_v2( async def send_leave_v2(
@ -445,7 +450,11 @@ class TransportLayerClient:
path = _create_v1_path("/invite/%s/%s", room_id, event_id) path = _create_v1_path("/invite/%s/%s", room_id, event_id)
return await self.client.put_json( return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True destination=destination,
path=path,
data=content,
ignore_backoff=True,
parser=LegacyJsonSendParser(),
) )
async def send_invite_v2( async def send_invite_v2(

View File

@ -17,7 +17,6 @@ import codecs
import logging import logging
import random import random
import sys import sys
import typing
import urllib.parse import urllib.parse
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO, StringIO from io import BytesIO, StringIO
@ -30,9 +29,11 @@ from typing import (
Generic, Generic,
List, List,
Optional, Optional,
TextIO,
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
cast,
overload, overload,
) )
@ -183,20 +184,61 @@ class MatrixFederationRequest:
return self.json return self.json
class JsonParser(ByteParser[Union[JsonDict, list]]): class _BaseJsonParser(ByteParser[T]):
"""A parser that buffers the response and tries to parse it as JSON.""" """A parser that buffers the response and tries to parse it as JSON."""
CONTENT_TYPE = "application/json" CONTENT_TYPE = "application/json"
def __init__(self) -> None: def __init__(
self, validator: Optional[Callable[[Optional[object]], bool]] = None
) -> None:
"""
Args:
validator: A callable which takes the parsed JSON value and returns
true if the value is valid.
"""
self._buffer = StringIO() self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer) self._binary_wrapper = BinaryIOWrapper(self._buffer)
self._validator = validator
def write(self, data: bytes) -> int: def write(self, data: bytes) -> int:
return self._binary_wrapper.write(data) return self._binary_wrapper.write(data)
def finish(self) -> Union[JsonDict, list]: def finish(self) -> T:
return json_decoder.decode(self._buffer.getvalue()) result = json_decoder.decode(self._buffer.getvalue())
if self._validator is not None and not self._validator(result):
raise ValueError(
f"Received incorrect JSON value: {result.__class__.__name__}"
)
return result
class JsonParser(_BaseJsonParser[JsonDict]):
"""A parser that buffers the response and tries to parse it as a JSON object."""
def __init__(self) -> None:
super().__init__(self._validate)
@staticmethod
def _validate(v: Any) -> bool:
return isinstance(v, dict)
class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
"""Ensure the legacy responses of /send_join & /send_leave are correct."""
def __init__(self) -> None:
super().__init__(self._validate)
@staticmethod
def _validate(v: Any) -> bool:
# Match [integer, JSON dict]
return (
isinstance(v, list)
and len(v) == 2
and type(v[0]) == int
and isinstance(v[1], dict)
)
async def _handle_response( async def _handle_response(
@ -313,9 +355,7 @@ async def _handle_response(
class BinaryIOWrapper: class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly.""" """A wrapper for a TextIO which converts from bytes on the fly."""
def __init__( def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
):
self.decoder = codecs.getincrementaldecoder(encoding)(errors) self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file self.file = file
@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False, backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None, parser: Literal[None] = None,
) -> Union[JsonDict, list]: ) -> JsonDict:
... ...
@overload @overload
@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False, ignore_backoff: bool = False,
backoff_on_404: bool = False, backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None, parser: Optional[ByteParser[T]] = None,
): ) -> Union[JsonDict, T]:
"""Sends the specified json data using PUT """Sends the specified json data using PUT
Args: Args:
@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
if parser is None: if parser is None:
parser = JsonParser() parser = cast(ByteParser[T], JsonParser())
body = await _handle_response( body = await _handle_response(
self.reactor, self.reactor,
@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None, timeout: Optional[int] = None,
ignore_backoff: bool = False, ignore_backoff: bool = False,
args: Optional[QueryParams] = None, args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]: ) -> JsonDict:
"""Sends the specified json data using POST """Sends the specified json data using POST
Args: Args:
@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False, ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None, parser: Literal[None] = None,
) -> Union[JsonDict, list]: ) -> JsonDict:
... ...
@overload @overload
@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None, timeout: Optional[int] = None,
ignore_backoff: bool = False, ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None, parser: Optional[ByteParser[T]] = None,
): ) -> Union[JsonDict, T]:
"""GETs some json from the given host homeserver and path """GETs some json from the given host homeserver and path
Args: Args:
@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
if parser is None: if parser is None:
parser = JsonParser() parser = cast(ByteParser[T], JsonParser())
body = await _handle_response( body = await _handle_response(
self.reactor, self.reactor,
@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None, timeout: Optional[int] = None,
ignore_backoff: bool = False, ignore_backoff: bool = False,
args: Optional[QueryParams] = None, args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]: ) -> JsonDict:
"""Send a DELETE request to the remote expecting some json response """Send a DELETE request to the remote expecting some json response
Args: Args:

View File

@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )

View File

@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel
from synapse.api.errors import RequestSendFailed from synapse.api.errors import RequestSendFailed
from synapse.http.matrixfederationclient import ( from synapse.http.matrixfederationclient import (
JsonParser, ByteParser,
MatrixFederationHttpClient, MatrixFederationHttpClient,
MatrixFederationRequest, MatrixFederationRequest,
) )
@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase):
while not test_d.called: while not test_d.called:
protocol.dataReceived(b"a" * chunk_size) protocol.dataReceived(b"a" * chunk_size)
sent += chunk_size sent += chunk_size
self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE) self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE) self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
f = self.failureResultOf(test_d) f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value, RequestSendFailed)