mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Finish type hints for federation client HTTP code. (#15465)
This commit is contained in:
parent
19141b9432
commit
ea5c3ede4f
1
changelog.d/15465.misc
Normal file
1
changelog.d/15465.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Improve type hints.
|
6
mypy.ini
6
mypy.ini
@ -33,12 +33,6 @@ exclude = (?x)
|
|||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
)$
|
)$
|
||||||
|
|
||||||
[mypy-synapse.federation.transport.client]
|
|
||||||
disallow_untyped_defs = False
|
|
||||||
|
|
||||||
[mypy-synapse.http.matrixfederationclient]
|
|
||||||
disallow_untyped_defs = False
|
|
||||||
|
|
||||||
[mypy-synapse.metrics._reactor_metrics]
|
[mypy-synapse.metrics._reactor_metrics]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
|
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
|
||||||
|
@ -280,15 +280,11 @@ class FederationClient(FederationBase):
|
|||||||
logger.debug("backfill transaction_data=%r", transaction_data)
|
logger.debug("backfill transaction_data=%r", transaction_data)
|
||||||
|
|
||||||
if not isinstance(transaction_data, dict):
|
if not isinstance(transaction_data, dict):
|
||||||
# TODO we probably want an exception type specific to federation
|
raise InvalidResponseError("Backfill transaction_data is not a dict.")
|
||||||
# client validation.
|
|
||||||
raise TypeError("Backfill transaction_data is not a dict.")
|
|
||||||
|
|
||||||
transaction_data_pdus = transaction_data.get("pdus")
|
transaction_data_pdus = transaction_data.get("pdus")
|
||||||
if not isinstance(transaction_data_pdus, list):
|
if not isinstance(transaction_data_pdus, list):
|
||||||
# TODO we probably want an exception type specific to federation
|
raise InvalidResponseError("transaction_data.pdus is not a list.")
|
||||||
# client validation.
|
|
||||||
raise TypeError("transaction_data.pdus is not a list.")
|
|
||||||
|
|
||||||
room_version = await self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
@ -42,18 +43,21 @@ from synapse.api.urls import (
|
|||||||
)
|
)
|
||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.federation.units import Transaction
|
from synapse.federation.units import Transaction
|
||||||
from synapse.http.matrixfederationclient import ByteParser
|
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
|
||||||
from synapse.http.types import QueryParams
|
from synapse.http.types import QueryParams
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import ExceptionBundle
|
from synapse.util import ExceptionBundle
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TransportLayerClient:
|
class TransportLayerClient:
|
||||||
"""Sends federation HTTP requests to other servers"""
|
"""Sends federation HTTP requests to other servers"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
||||||
@ -133,7 +137,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
async def backfill(
|
async def backfill(
|
||||||
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
|
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
|
||||||
) -> Optional[JsonDict]:
|
) -> Optional[Union[JsonDict, list]]:
|
||||||
"""Requests `limit` previous PDUs in a given context before list of
|
"""Requests `limit` previous PDUs in a given context before list of
|
||||||
PDUs.
|
PDUs.
|
||||||
|
|
||||||
@ -388,6 +392,7 @@ class TransportLayerClient:
|
|||||||
# server was just having a momentary blip, the room will be out of
|
# server was just having a momentary blip, the room will be out of
|
||||||
# sync.
|
# sync.
|
||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
|
parser=LegacyJsonSendParser(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_leave_v2(
|
async def send_leave_v2(
|
||||||
@ -445,7 +450,11 @@ class TransportLayerClient:
|
|||||||
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
|
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
|
||||||
|
|
||||||
return await self.client.put_json(
|
return await self.client.put_json(
|
||||||
destination=destination, path=path, data=content, ignore_backoff=True
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
parser=LegacyJsonSendParser(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_invite_v2(
|
async def send_invite_v2(
|
||||||
|
@ -17,7 +17,6 @@ import codecs
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import typing
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO, StringIO
|
from io import BytesIO, StringIO
|
||||||
@ -30,9 +29,11 @@ from typing import (
|
|||||||
Generic,
|
Generic,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
TextIO,
|
||||||
Tuple,
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -183,20 +184,61 @@ class MatrixFederationRequest:
|
|||||||
return self.json
|
return self.json
|
||||||
|
|
||||||
|
|
||||||
class JsonParser(ByteParser[Union[JsonDict, list]]):
|
class _BaseJsonParser(ByteParser[T]):
|
||||||
"""A parser that buffers the response and tries to parse it as JSON."""
|
"""A parser that buffers the response and tries to parse it as JSON."""
|
||||||
|
|
||||||
CONTENT_TYPE = "application/json"
|
CONTENT_TYPE = "application/json"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
|
self, validator: Optional[Callable[[Optional[object]], bool]] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
validator: A callable which takes the parsed JSON value and returns
|
||||||
|
true if the value is valid.
|
||||||
|
"""
|
||||||
self._buffer = StringIO()
|
self._buffer = StringIO()
|
||||||
self._binary_wrapper = BinaryIOWrapper(self._buffer)
|
self._binary_wrapper = BinaryIOWrapper(self._buffer)
|
||||||
|
self._validator = validator
|
||||||
|
|
||||||
def write(self, data: bytes) -> int:
|
def write(self, data: bytes) -> int:
|
||||||
return self._binary_wrapper.write(data)
|
return self._binary_wrapper.write(data)
|
||||||
|
|
||||||
def finish(self) -> Union[JsonDict, list]:
|
def finish(self) -> T:
|
||||||
return json_decoder.decode(self._buffer.getvalue())
|
result = json_decoder.decode(self._buffer.getvalue())
|
||||||
|
if self._validator is not None and not self._validator(result):
|
||||||
|
raise ValueError(
|
||||||
|
f"Received incorrect JSON value: {result.__class__.__name__}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class JsonParser(_BaseJsonParser[JsonDict]):
|
||||||
|
"""A parser that buffers the response and tries to parse it as a JSON object."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(self._validate)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate(v: Any) -> bool:
|
||||||
|
return isinstance(v, dict)
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
|
||||||
|
"""Ensure the legacy responses of /send_join & /send_leave are correct."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(self._validate)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate(v: Any) -> bool:
|
||||||
|
# Match [integer, JSON dict]
|
||||||
|
return (
|
||||||
|
isinstance(v, list)
|
||||||
|
and len(v) == 2
|
||||||
|
and type(v[0]) == int
|
||||||
|
and isinstance(v[1], dict)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_response(
|
async def _handle_response(
|
||||||
@ -313,9 +355,7 @@ async def _handle_response(
|
|||||||
class BinaryIOWrapper:
|
class BinaryIOWrapper:
|
||||||
"""A wrapper for a TextIO which converts from bytes on the fly."""
|
"""A wrapper for a TextIO which converts from bytes on the fly."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
|
||||||
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
|
|
||||||
):
|
|
||||||
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
|
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
|
||||||
self.file = file
|
self.file = file
|
||||||
|
|
||||||
@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
|
|||||||
backoff_on_404: bool = False,
|
backoff_on_404: bool = False,
|
||||||
try_trailing_slash_on_400: bool = False,
|
try_trailing_slash_on_400: bool = False,
|
||||||
parser: Literal[None] = None,
|
parser: Literal[None] = None,
|
||||||
) -> Union[JsonDict, list]:
|
) -> JsonDict:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
|
|||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
backoff_on_404: bool = False,
|
backoff_on_404: bool = False,
|
||||||
try_trailing_slash_on_400: bool = False,
|
try_trailing_slash_on_400: bool = False,
|
||||||
parser: Optional[ByteParser] = None,
|
parser: Optional[ByteParser[T]] = None,
|
||||||
):
|
) -> Union[JsonDict, T]:
|
||||||
"""Sends the specified json data using PUT
|
"""Sends the specified json data using PUT
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
|
|||||||
_sec_timeout = self.default_timeout
|
_sec_timeout = self.default_timeout
|
||||||
|
|
||||||
if parser is None:
|
if parser is None:
|
||||||
parser = JsonParser()
|
parser = cast(ByteParser[T], JsonParser())
|
||||||
|
|
||||||
body = await _handle_response(
|
body = await _handle_response(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
|
|||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
) -> Union[JsonDict, list]:
|
) -> JsonDict:
|
||||||
"""Sends the specified json data using POST
|
"""Sends the specified json data using POST
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
|
|||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
try_trailing_slash_on_400: bool = False,
|
try_trailing_slash_on_400: bool = False,
|
||||||
parser: Literal[None] = None,
|
parser: Literal[None] = None,
|
||||||
) -> Union[JsonDict, list]:
|
) -> JsonDict:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
|
|||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
try_trailing_slash_on_400: bool = False,
|
try_trailing_slash_on_400: bool = False,
|
||||||
parser: Optional[ByteParser] = None,
|
parser: Optional[ByteParser[T]] = None,
|
||||||
):
|
) -> Union[JsonDict, T]:
|
||||||
"""GETs some json from the given host homeserver and path
|
"""GETs some json from the given host homeserver and path
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
|
|||||||
_sec_timeout = self.default_timeout
|
_sec_timeout = self.default_timeout
|
||||||
|
|
||||||
if parser is None:
|
if parser is None:
|
||||||
parser = JsonParser()
|
parser = cast(ByteParser[T], JsonParser())
|
||||||
|
|
||||||
body = await _handle_response(
|
body = await _handle_response(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
|
|||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
) -> Union[JsonDict, list]:
|
) -> JsonDict:
|
||||||
"""Send a DELETE request to the remote expecting some json response
|
"""Send a DELETE request to the remote expecting some json response
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||||||
fed_transport = self.hs.get_federation_transport_client()
|
fed_transport = self.hs.get_federation_transport_client()
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||||||
fed_transport = self.hs.get_federation_transport_client()
|
fed_transport = self.hs.get_federation_transport_client()
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||||||
fed_transport = self.hs.get_federation_transport_client()
|
fed_transport = self.hs.get_federation_transport_client()
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
|||||||
fed_transport = self.hs.get_federation_transport_client()
|
fed_transport = self.hs.get_federation_transport_client()
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
|||||||
fed_transport = self.hs.get_federation_transport_client()
|
fed_transport = self.hs.get_federation_transport_client()
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
|
@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel
|
|||||||
|
|
||||||
from synapse.api.errors import RequestSendFailed
|
from synapse.api.errors import RequestSendFailed
|
||||||
from synapse.http.matrixfederationclient import (
|
from synapse.http.matrixfederationclient import (
|
||||||
JsonParser,
|
ByteParser,
|
||||||
MatrixFederationHttpClient,
|
MatrixFederationHttpClient,
|
||||||
MatrixFederationRequest,
|
MatrixFederationRequest,
|
||||||
)
|
)
|
||||||
@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
while not test_d.called:
|
while not test_d.called:
|
||||||
protocol.dataReceived(b"a" * chunk_size)
|
protocol.dataReceived(b"a" * chunk_size)
|
||||||
sent += chunk_size
|
sent += chunk_size
|
||||||
self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
|
self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
|
||||||
|
|
||||||
self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
|
self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
|
||||||
|
|
||||||
f = self.failureResultOf(test_d)
|
f = self.failureResultOf(test_d)
|
||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
|
Loading…
Reference in New Issue
Block a user