mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 12:16:09 -04:00
Add type hints for tests/unittest.py
. (#12347)
In particular, add type hints for get_success and friends, which are then helpful in a bunch of places.
This commit is contained in:
parent
33ebee47e4
commit
f0b03186d9
12 changed files with 97 additions and 48 deletions
|
@ -22,10 +22,11 @@ import secrets
|
|||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
AnyStr,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
|
@ -39,6 +40,7 @@ from unittest.mock import Mock, patch
|
|||
import canonicaljson
|
||||
import signedjson.key
|
||||
import unpaddedbase64
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from twisted.internet.defer import Deferred, ensureDeferred
|
||||
from twisted.python.failure import Failure
|
||||
|
@ -49,7 +51,7 @@ from twisted.web.resource import Resource
|
|||
from twisted.web.server import Request
|
||||
|
||||
from synapse import events
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||
|
@ -70,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
|
|||
from synapse.util import Clock
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
|
||||
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
|
||||
from tests.server import (
|
||||
CustomHeaderType,
|
||||
FakeChannel,
|
||||
get_clock,
|
||||
make_request,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
from tests.test_utils import event_injection, setup_awaitable_errors
|
||||
from tests.test_utils.logging_setup import setup_logging
|
||||
from tests.utils import default_config, setupdb
|
||||
|
@ -78,6 +86,17 @@ from tests.utils import default_config, setupdb
|
|||
setupdb()
|
||||
setup_logging()
|
||||
|
||||
TV = TypeVar("TV")
|
||||
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
|
||||
|
||||
|
||||
class _TypedFailure(Generic[_ExcType], Protocol):
|
||||
"""Extension to twisted.Failure, where the 'value' has a certain type."""
|
||||
|
||||
@property
|
||||
def value(self) -> _ExcType:
|
||||
...
|
||||
|
||||
|
||||
def around(target):
|
||||
"""A CLOS-style 'around' modifier, which wraps the original method of the
|
||||
|
@ -276,6 +295,7 @@ class HomeserverTestCase(TestCase):
|
|||
|
||||
if hasattr(self, "user_id"):
|
||||
if self.hijack_auth:
|
||||
assert self.helper.auth_user_id is not None
|
||||
|
||||
# We need a valid token ID to satisfy foreign key constraints.
|
||||
token_id = self.get_success(
|
||||
|
@ -288,6 +308,7 @@ class HomeserverTestCase(TestCase):
|
|||
)
|
||||
|
||||
async def get_user_by_access_token(token=None, allow_guest=False):
|
||||
assert self.helper.auth_user_id is not None
|
||||
return {
|
||||
"user": UserID.from_string(self.helper.auth_user_id),
|
||||
"token_id": token_id,
|
||||
|
@ -295,6 +316,7 @@ class HomeserverTestCase(TestCase):
|
|||
}
|
||||
|
||||
async def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||
assert self.helper.auth_user_id is not None
|
||||
return create_requester(
|
||||
UserID.from_string(self.helper.auth_user_id),
|
||||
token_id,
|
||||
|
@ -311,7 +333,7 @@ class HomeserverTestCase(TestCase):
|
|||
)
|
||||
|
||||
if self.needs_threadpool:
|
||||
self.reactor.threadpool = ThreadPool()
|
||||
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
|
||||
self.addCleanup(self.reactor.threadpool.stop)
|
||||
self.reactor.threadpool.start()
|
||||
|
||||
|
@ -426,7 +448,7 @@ class HomeserverTestCase(TestCase):
|
|||
federation_auth_origin: Optional[bytes] = None,
|
||||
content_is_form: bool = False,
|
||||
await_result: bool = True,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||
client_ip: str = "127.0.0.1",
|
||||
) -> FakeChannel:
|
||||
"""
|
||||
|
@ -511,30 +533,36 @@ class HomeserverTestCase(TestCase):
|
|||
|
||||
return hs
|
||||
|
||||
def pump(self, by=0.0):
|
||||
def pump(self, by: float = 0.0) -> None:
|
||||
"""
|
||||
Pump the reactor enough that Deferreds will fire.
|
||||
"""
|
||||
self.reactor.pump([by] * 100)
|
||||
|
||||
def get_success(self, d, by=0.0):
|
||||
deferred: Deferred[TV] = ensureDeferred(d)
|
||||
def get_success(
|
||||
self,
|
||||
d: Awaitable[TV],
|
||||
by: float = 0.0,
|
||||
) -> TV:
|
||||
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
|
||||
self.pump(by=by)
|
||||
return self.successResultOf(deferred)
|
||||
|
||||
def get_failure(self, d, exc):
|
||||
def get_failure(
|
||||
self, d: Awaitable[Any], exc: Type[_ExcType]
|
||||
) -> _TypedFailure[_ExcType]:
|
||||
"""
|
||||
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
|
||||
"""
|
||||
deferred: Deferred[Any] = ensureDeferred(d)
|
||||
deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
|
||||
self.pump()
|
||||
return self.failureResultOf(deferred, exc)
|
||||
|
||||
def get_success_or_raise(self, d, by=0.0):
|
||||
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||
"""Drive deferred to completion and return result or raise exception
|
||||
on failure.
|
||||
"""
|
||||
deferred: Deferred[TV] = ensureDeferred(d)
|
||||
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
|
||||
|
||||
results: list = []
|
||||
deferred.addBoth(results.append)
|
||||
|
@ -642,11 +670,11 @@ class HomeserverTestCase(TestCase):
|
|||
|
||||
def login(
|
||||
self,
|
||||
username,
|
||||
password,
|
||||
device_id=None,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
):
|
||||
username: str,
|
||||
password: str,
|
||||
device_id: Optional[str] = None,
|
||||
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Log in a user, and get an access token. Requires the Login API be
|
||||
registered.
|
||||
|
@ -668,18 +696,22 @@ class HomeserverTestCase(TestCase):
|
|||
return access_token
|
||||
|
||||
def create_and_send_event(
|
||||
self, room_id, user, soft_failed=False, prev_event_ids=None
|
||||
):
|
||||
self,
|
||||
room_id: str,
|
||||
user: UserID,
|
||||
soft_failed: bool = False,
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create and send an event.
|
||||
|
||||
Args:
|
||||
soft_failed (bool): Whether to create a soft failed event or not
|
||||
prev_event_ids (list[str]|None): Explicitly set the prev events,
|
||||
soft_failed: Whether to create a soft failed event or not
|
||||
prev_event_ids: Explicitly set the prev events,
|
||||
or if None just use the default
|
||||
|
||||
Returns:
|
||||
str: The new event's ID.
|
||||
The new event's ID.
|
||||
"""
|
||||
event_creator = self.hs.get_event_creation_handler()
|
||||
requester = create_requester(user)
|
||||
|
@ -706,7 +738,7 @@ class HomeserverTestCase(TestCase):
|
|||
|
||||
return event.event_id
|
||||
|
||||
def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
|
||||
def inject_room_member(self, room: str, user: str, membership: str) -> None:
|
||||
"""
|
||||
Inject a membership event into a room.
|
||||
|
||||
|
@ -766,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
|||
path: str,
|
||||
content: Optional[JsonDict] = None,
|
||||
await_result: bool = True,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||
client_ip: str = "127.0.0.1",
|
||||
) -> FakeChannel:
|
||||
"""Make an inbound signed federation request to this server
|
||||
|
@ -799,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
|||
self.site,
|
||||
method=method,
|
||||
path=path,
|
||||
content=content,
|
||||
content=content or "",
|
||||
shorthand=False,
|
||||
await_result=await_result,
|
||||
custom_headers=custom_headers,
|
||||
|
@ -878,9 +910,6 @@ def override_config(extra_config):
|
|||
return decorator
|
||||
|
||||
|
||||
TV = TypeVar("TV")
|
||||
|
||||
|
||||
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
|
||||
"""A test decorator which will skip the decorated test unless a condition is set
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue