annotate tests.server.FakeChannel (#13136)

This commit is contained in:
David Robertson 2022-07-04 18:08:56 +01:00 committed by GitHub
parent 5b5c943e7d
commit d102ad67fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 36 additions and 26 deletions

1
changelog.d/13136.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations to `tests.server`.

View File

@ -1579,8 +1579,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
def test_single_room(self) -> None: def test_single_room(self) -> None:
"""Test that a single room can be requested correctly""" """Test that a single room can be requested correctly"""

View File

@ -1488,7 +1488,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
if channel.code != HTTPStatus.OK: if channel.code != HTTPStatus.OK:
raise HttpResponseException( raise HttpResponseException(
channel.code, channel.result["reason"], channel.json_body channel.code, channel.result["reason"], channel.result["body"]
) )
# Set monthly active users to the limit # Set monthly active users to the limit

View File

@ -949,7 +949,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
client_secret: str, client_secret: str,
next_link: Optional[str] = None, next_link: Optional[str] = None,
expect_code: int = 200, expect_code: int = 200,
) -> str: ) -> Optional[str]:
"""Request a validation token to add an email address to a user's account """Request a validation token to add an email address to a user's account
Args: Args:
@ -959,7 +959,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
expect_code: Expected return code of the call expect_code: Expected return code of the call
Returns: Returns:
The ID of the new threepid validation session The ID of the new threepid validation session, or None if the response
did not contain a session ID.
""" """
body = {"client_secret": client_secret, "email": email, "send_attempt": 1} body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
if next_link: if next_link:

View File

@ -153,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
def _get_displayname(self, name: Optional[str] = None) -> str: def _get_displayname(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request( channel = self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,) "GET", "/profile/%s/displayname" % (name or self.owner,)
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"] # FIXME: If a user has no displayname set, Synapse returns 200 and omits a
# displayname from the response. This contradicts the spec, see #13137.
return channel.json_body.get("displayname")
def _get_avatar_url(self, name: Optional[str] = None) -> str: def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request( channel = self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,) "GET", "/profile/%s/avatar_url" % (name or self.owner,)
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
# FIXME: If a user has no avatar set, Synapse returns 200 and omits an
# avatar_url from the response. This contradicts the spec, see #13137.
return channel.json_body.get("avatar_url") return channel.json_body.get("avatar_url")
@unittest.override_config({"max_avatar_size": 50}) @unittest.override_config({"max_avatar_size": 50})

View File

@ -800,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
) )
expected_event_ids.append(channel.json_body["event_id"]) expected_event_ids.append(channel.json_body["event_id"])
prev_token = "" prev_token: Optional[str] = ""
found_event_ids: List[str] = [] found_event_ids: List[str] = []
for _ in range(20): for _ in range(20):
from_token = "" from_token = ""

View File

@ -43,6 +43,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IAddress, IAddress,
IConsumer,
IHostnameResolver, IHostnameResolver,
IProtocol, IProtocol,
IPullProducer, IPullProducer,
@ -53,11 +54,7 @@ from twisted.internet.interfaces import (
ITransport, ITransport,
) )
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import ( from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
AccumulatingProtocol,
MemoryReactor,
MemoryReactorClock,
)
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.resource import IResource from twisted.web.resource import IResource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
@ -96,6 +93,7 @@ class TimedOutException(Exception):
""" """
@implementer(IConsumer)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class FakeChannel: class FakeChannel:
""" """
@ -104,7 +102,7 @@ class FakeChannel:
""" """
site: Union[Site, "FakeSite"] site: Union[Site, "FakeSite"]
_reactor: MemoryReactor _reactor: MemoryReactorClock
result: dict = attr.Factory(dict) result: dict = attr.Factory(dict)
_ip: str = "127.0.0.1" _ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None _producer: Optional[Union[IPullProducer, IPushProducer]] = None
@ -122,7 +120,7 @@ class FakeChannel:
self._request = request self._request = request
@property @property
def json_body(self): def json_body(self) -> JsonDict:
return json.loads(self.text_body) return json.loads(self.text_body)
@property @property
@ -140,7 +138,7 @@ class FakeChannel:
return self.result.get("done", False) return self.result.get("done", False)
@property @property
def code(self): def code(self) -> int:
if not self.result: if not self.result:
raise Exception("No result yet.") raise Exception("No result yet.")
return int(self.result["code"]) return int(self.result["code"])
@ -160,7 +158,7 @@ class FakeChannel:
self.result["reason"] = reason self.result["reason"] = reason
self.result["headers"] = headers self.result["headers"] = headers
def write(self, content): def write(self, content: bytes) -> None:
assert isinstance(content, bytes), "Should be bytes! " + repr(content) assert isinstance(content, bytes), "Should be bytes! " + repr(content)
if "body" not in self.result: if "body" not in self.result:
@ -168,11 +166,16 @@ class FakeChannel:
self.result["body"] += content self.result["body"] += content
def registerProducer(self, producer, streaming): # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
def registerProducer( # type: ignore[override]
self,
producer: Union[IPullProducer, IPushProducer],
streaming: bool,
) -> None:
self._producer = producer self._producer = producer
self.producerStreaming = streaming self.producerStreaming = streaming
def _produce(): def _produce() -> None:
if self._producer: if self._producer:
self._producer.resumeProducing() self._producer.resumeProducing()
self._reactor.callLater(0.1, _produce) self._reactor.callLater(0.1, _produce)
@ -180,31 +183,32 @@ class FakeChannel:
if not streaming: if not streaming:
self._reactor.callLater(0.0, _produce) self._reactor.callLater(0.0, _produce)
def unregisterProducer(self): def unregisterProducer(self) -> None:
if self._producer is None: if self._producer is None:
return return
self._producer = None self._producer = None
def requestDone(self, _self): def requestDone(self, _self: Request) -> None:
self.result["done"] = True self.result["done"] = True
if isinstance(_self, SynapseRequest): if isinstance(_self, SynapseRequest):
assert _self.logcontext is not None
self.resource_usage = _self.logcontext.get_resource_usage() self.resource_usage = _self.logcontext.get_resource_usage()
def getPeer(self): def getPeer(self) -> IAddress:
# We give an address so that getClientAddress/getClientIP returns a non null entry, # We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU # causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423) return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self): def getHost(self) -> IAddress:
# this is called by Request.__init__ to configure Request.host. # this is called by Request.__init__ to configure Request.host.
return address.IPv4Address("TCP", "127.0.0.1", 8888) return address.IPv4Address("TCP", "127.0.0.1", 8888)
def isSecure(self): def isSecure(self) -> bool:
return False return False
@property @property
def transport(self): def transport(self) -> "FakeChannel":
return self return self
def await_result(self, timeout_ms: int = 1000) -> None: def await_result(self, timeout_ms: int = 1000) -> None: