Type tests.utils (#13028)

* Cast to postgres types when handling postgres db

* Remove unused method

* Easy annotations

* Annotate create_room

* Use `ParamSpec` to annotate looping_call

* Annotate `default_config`

* Track `now` as a float

`time_ms` returns an int like the proper Synapse `Clock`

* Introduce a `Timer` dataclass

* Introduce a Looper type

* Suppress checking of a mock

* tests.utils is typed

* Changelog

* Whoops, import ParamSpec from typing_extensions

* ditch the psycopg2 casts
This commit is contained in:
David Robertson 2022-07-05 15:13:47 +01:00 committed by GitHub
parent 68695d8007
commit 6ba732fefe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 101 additions and 45 deletions

1
changelog.d/13028.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations to `tests.utils`.

View File

@ -126,6 +126,9 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client] [mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.utils]
disallow_untyped_defs = True
;; Dependencies without annotations ;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available. ;; Before ignoring a module, check to see if type stubs are available.

View File

@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Generator, Optional
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from matrix_common.versionstring import get_distribution_version_string from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec
from twisted.internet import defer, task from twisted.internet import defer, task
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
@ -82,6 +83,9 @@ def unwrapFirstError(failure: Failure) -> Failure:
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
P = ParamSpec("P")
@attr.s(slots=True) @attr.s(slots=True)
class Clock: class Clock:
""" """
@ -110,7 +114,7 @@ class Clock:
return int(self.time() * 1000) return int(self.time() * 1000)
def looping_call( def looping_call(
self, f: Callable, msec: float, *args: Any, **kwargs: Any self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs
) -> LoopingCall: ) -> LoopingCall:
"""Call a function repeatedly. """Call a function repeatedly.

View File

@ -109,7 +109,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
@wrap_as_background_process("LruCache._expire_old_entries") @wrap_as_background_process("LruCache._expire_old_entries")
async def _expire_old_entries( async def _expire_old_entries(
clock: Clock, expiry_seconds: int, autotune_config: Optional[dict] clock: Clock, expiry_seconds: float, autotune_config: Optional[dict]
) -> None: ) -> None:
"""Walks the global cache list to find cache entries that haven't been """Walks the global cache list to find cache entries that haven't been
accessed in the given number of seconds, or if a given memory threshold has been breached. accessed in the given number of seconds, or if a given memory threshold has been breached.

View File

@ -15,12 +15,17 @@
import atexit import atexit
import os import os
from typing import Any, Callable, Dict, List, Tuple, Union, overload
import attr
from typing_extensions import Literal, ParamSpec
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
@ -50,12 +55,11 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
def setupdb(): def setupdb() -> None:
# If we're using PostgreSQL, set up the db once # If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS: if USE_POSTGRES_FOR_TESTS:
# create a PostgresEngine # create a PostgresEngine
db_engine = create_engine({"name": "psycopg2", "args": {}}) db_engine = create_engine({"name": "psycopg2", "args": {}})
# connect to postgres to create the base database. # connect to postgres to create the base database.
db_conn = db_engine.module.connect( db_conn = db_engine.module.connect(
user=POSTGRES_USER, user=POSTGRES_USER,
@ -82,11 +86,11 @@ def setupdb():
port=POSTGRES_PORT, port=POSTGRES_PORT,
password=POSTGRES_PASSWORD, password=POSTGRES_PASSWORD,
) )
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None) prepare_database(logging_conn, db_engine, None)
db_conn.close() logging_conn.close()
def _cleanup(): def _cleanup() -> None:
db_conn = db_engine.module.connect( db_conn = db_engine.module.connect(
user=POSTGRES_USER, user=POSTGRES_USER,
host=POSTGRES_HOST, host=POSTGRES_HOST,
@ -103,7 +107,19 @@ def setupdb():
atexit.register(_cleanup) atexit.register(_cleanup)
def default_config(name, parse=False): @overload
def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
...
@overload
def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
...
def default_config(
name: str, parse: bool = False
) -> Union[Dict[str, object], HomeServerConfig]:
""" """
Create a reasonable test config. Create a reasonable test config.
""" """
@ -181,90 +197,122 @@ def default_config(name, parse=False):
return config_dict return config_dict
def mock_getRawHeaders(headers=None): def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def]
headers = headers if headers is not None else {} headers = headers if headers is not None else {}
def getRawHeaders(name, default=None): def getRawHeaders(name, default=None): # type: ignore[no-untyped-def]
# If the requested header is present, the real twisted function returns
# List[str] if name is a str and List[bytes] if name is a bytes.
# This mock doesn't support that behaviour.
# Fortunately, none of the current callers of mock_getRawHeaders() provide a
# headers dict, so we don't encounter this discrepancy in practice.
return headers.get(name, default) return headers.get(name, default)
return getRawHeaders return getRawHeaders
P = ParamSpec("P")
@attr.s(slots=True, auto_attribs=True)
class Timer:
absolute_time: float
callback: Callable[[], None]
expired: bool
# TODO: Make this generic over a ParamSpec?
@attr.s(slots=True, auto_attribs=True)
class Looper:
func: Callable[..., Any]
interval: float # seconds
last: float
args: Tuple[object, ...]
kwargs: Dict[str, object]
class MockClock: class MockClock:
now = 1000 now = 1000.0
def __init__(self): def __init__(self) -> None:
# list of lists of [absolute_time, callback, expired] in no particular # Timers in no particular order
# order self.timers: List[Timer] = []
self.timers = [] self.loopers: List[Looper] = []
self.loopers = []
def time(self): def time(self) -> float:
return self.now return self.now
def time_msec(self): def time_msec(self) -> int:
return self.time() * 1000 return int(self.time() * 1000)
def call_later(self, delay, callback, *args, **kwargs): def call_later(
self,
delay: float,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> Timer:
ctx = current_context() ctx = current_context()
def wrapped_callback(): def wrapped_callback() -> None:
set_current_context(ctx) set_current_context(ctx)
callback(*args, **kwargs) callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False] t = Timer(self.now + delay, wrapped_callback, False)
self.timers.append(t) self.timers.append(t)
return t return t
def looping_call(self, function, interval, *args, **kwargs): def looping_call(
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) self,
function: Callable[P, object],
interval: float,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
# This type-ignore should be redundant once we use a mypy release with
# https://github.com/python/mypy/pull/12668.
self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type]
def cancel_call_later(self, timer, ignore_errs=False): def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
if timer[2]: if timer.expired:
if not ignore_errs: if not ignore_errs:
raise Exception("Cannot cancel an expired timer") raise Exception("Cannot cancel an expired timer")
timer[2] = True timer.expired = True
self.timers = [t for t in self.timers if t != timer] self.timers = [t for t in self.timers if t != timer]
# For unit testing # For unit testing
def advance_time(self, secs): def advance_time(self, secs: float) -> None:
self.now += secs self.now += secs
timers = self.timers timers = self.timers
self.timers = [] self.timers = []
for t in timers: for t in timers:
time, callback, expired = t if t.expired:
if expired:
raise Exception("Timer already expired") raise Exception("Timer already expired")
if self.now >= time: if self.now >= t.absolute_time:
t[2] = True t.expired = True
callback() t.callback()
else: else:
self.timers.append(t) self.timers.append(t)
for looped in self.loopers: for looped in self.loopers:
func, interval, last, args, kwargs = looped if looped.last + looped.interval < self.now:
if last + interval < self.now: looped.func(*looped.args, **looped.kwargs)
func(*args, **kwargs) looped.last = self.now
looped[2] = self.now
def advance_time_msec(self, ms): def advance_time_msec(self, ms: float) -> None:
self.advance_time(ms / 1000.0) self.advance_time(ms / 1000.0)
def time_bound_deferred(self, d, *args, **kwargs):
# We don't bother timing things out for now.
return d
async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room""" """Creates and persist a creation event for the given room"""
persistence_store = hs.get_storage_controllers().persistence persistence_store = hs.get_storage_controllers().persistence
assert persistence_store is not None
store = hs.get_datastores().main store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory() event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler() event_creation_handler = hs.get_event_creation_handler()