mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Add missing types to tests.util. (#14597)
Removes files under tests.util from the ignored by list, then fully types all tests/util/*.py files.
This commit is contained in:
parent
fac8a38525
commit
acea4d7a2f
1
changelog.d/14597.misc
Normal file
1
changelog.d/14597.misc
Normal file
@ -0,0 +1 @@
|
||||
Add missing type hints.
|
13
mypy.ini
13
mypy.ini
@ -59,16 +59,6 @@ exclude = (?x)
|
||||
|tests/server_notices/test_resource_limits_server_notices.py
|
||||
|tests/test_state.py
|
||||
|tests/test_terms_auth.py
|
||||
|tests/util/test_async_helpers.py
|
||||
|tests/util/test_batching_queue.py
|
||||
|tests/util/test_dict_cache.py
|
||||
|tests/util/test_expiring_cache.py
|
||||
|tests/util/test_file_consumer.py
|
||||
|tests/util/test_linearizer.py
|
||||
|tests/util/test_logcontext.py
|
||||
|tests/util/test_lrucache.py
|
||||
|tests/util/test_rwlock.py
|
||||
|tests/util/test_wheel_timer.py
|
||||
)$
|
||||
|
||||
[mypy-synapse.federation.transport.client]
|
||||
@ -137,6 +127,9 @@ disallow_untyped_defs = True
|
||||
[mypy-tests.util.caches.test_descriptors]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-tests.util.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.utils]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import traceback
|
||||
from typing import Generator, List, NoReturn, Optional
|
||||
|
||||
from parameterized import parameterized_class
|
||||
|
||||
@ -41,8 +42,8 @@ from tests.unittest import TestCase
|
||||
|
||||
|
||||
class ObservableDeferredTest(TestCase):
|
||||
def test_succeed(self):
|
||||
origin_d = Deferred()
|
||||
def test_succeed(self) -> None:
|
||||
origin_d: "Deferred[int]" = Deferred()
|
||||
observable = ObservableDeferred(origin_d)
|
||||
|
||||
observer1 = observable.observe()
|
||||
@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase):
|
||||
self.assertFalse(observer2.called)
|
||||
|
||||
# check the first observer is called first
|
||||
def check_called_first(res):
|
||||
def check_called_first(res: int) -> int:
|
||||
self.assertFalse(observer2.called)
|
||||
return res
|
||||
|
||||
observer1.addBoth(check_called_first)
|
||||
|
||||
# store the results
|
||||
results = [None, None]
|
||||
results: List[Optional[ObservableDeferred[int]]] = [None, None]
|
||||
|
||||
def check_val(res, idx):
|
||||
def check_val(
|
||||
res: ObservableDeferred[int], idx: int
|
||||
) -> ObservableDeferred[int]:
|
||||
results[idx] = res
|
||||
return res
|
||||
|
||||
@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase):
|
||||
self.assertEqual(results[0], 123, "observer 1 callback result")
|
||||
self.assertEqual(results[1], 123, "observer 2 callback result")
|
||||
|
||||
def test_failure(self):
|
||||
origin_d = Deferred()
|
||||
def test_failure(self) -> None:
|
||||
origin_d: Deferred = Deferred()
|
||||
observable = ObservableDeferred(origin_d, consumeErrors=True)
|
||||
|
||||
observer1 = observable.observe()
|
||||
@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase):
|
||||
self.assertFalse(observer2.called)
|
||||
|
||||
# check the first observer is called first
|
||||
def check_called_first(res):
|
||||
def check_called_first(res: int) -> int:
|
||||
self.assertFalse(observer2.called)
|
||||
return res
|
||||
|
||||
observer1.addBoth(check_called_first)
|
||||
|
||||
# store the results
|
||||
results = [None, None]
|
||||
results: List[Optional[ObservableDeferred[str]]] = [None, None]
|
||||
|
||||
def check_val(res, idx):
|
||||
def check_val(res: ObservableDeferred[str], idx: int) -> None:
|
||||
results[idx] = res
|
||||
return None
|
||||
|
||||
@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase):
|
||||
raise Exception("gah!")
|
||||
except Exception as e:
|
||||
origin_d.errback(e)
|
||||
assert results[0] is not None
|
||||
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
|
||||
assert results[1] is not None
|
||||
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
|
||||
|
||||
def test_cancellation(self):
|
||||
def test_cancellation(self) -> None:
|
||||
"""Test that cancelling an observer does not affect other observers."""
|
||||
origin_d: "Deferred[int]" = Deferred()
|
||||
observable = ObservableDeferred(origin_d, consumeErrors=True)
|
||||
@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase):
|
||||
|
||||
|
||||
class TimeoutDeferredTest(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.clock = Clock()
|
||||
|
||||
def test_times_out(self):
|
||||
def test_times_out(self) -> None:
|
||||
"""Basic test case that checks that the original deferred is cancelled and that
|
||||
the timing-out deferred is errbacked
|
||||
"""
|
||||
cancelled = [False]
|
||||
cancelled = False
|
||||
|
||||
def canceller(_d):
|
||||
cancelled[0] = True
|
||||
def canceller(_d: Deferred) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
non_completing_d = Deferred(canceller)
|
||||
non_completing_d: Deferred = Deferred(canceller)
|
||||
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
|
||||
|
||||
self.assertNoResult(timing_out_d)
|
||||
self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
|
||||
self.assertFalse(cancelled, "deferred was cancelled prematurely")
|
||||
|
||||
self.clock.pump((1.0,))
|
||||
|
||||
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
|
||||
self.assertTrue(cancelled, "deferred was not cancelled by timeout")
|
||||
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||
|
||||
def test_times_out_when_canceller_throws(self):
|
||||
def test_times_out_when_canceller_throws(self) -> None:
|
||||
"""Test that we have successfully worked around
|
||||
https://twistedmatrix.com/trac/ticket/9534"""
|
||||
|
||||
def canceller(_d):
|
||||
def canceller(_d: Deferred) -> None:
|
||||
raise Exception("can't cancel this deferred")
|
||||
|
||||
non_completing_d = Deferred(canceller)
|
||||
non_completing_d: Deferred = Deferred(canceller)
|
||||
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
|
||||
|
||||
self.assertNoResult(timing_out_d)
|
||||
@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase):
|
||||
|
||||
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||
|
||||
def test_logcontext_is_preserved_on_cancellation(self):
|
||||
blocking_was_cancelled = [False]
|
||||
def test_logcontext_is_preserved_on_cancellation(self) -> None:
|
||||
blocking_was_cancelled = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def blocking():
|
||||
non_completing_d = Deferred()
|
||||
def blocking() -> Generator["Deferred[object]", object, None]:
|
||||
nonlocal blocking_was_cancelled
|
||||
|
||||
non_completing_d: Deferred = Deferred()
|
||||
with PreserveLoggingContext():
|
||||
try:
|
||||
yield non_completing_d
|
||||
except CancelledError:
|
||||
blocking_was_cancelled[0] = True
|
||||
blocking_was_cancelled = True
|
||||
raise
|
||||
|
||||
with LoggingContext("one") as context_one:
|
||||
# the errbacks should be run in the test logcontext
|
||||
def errback(res, deferred_name):
|
||||
def errback(res: Failure, deferred_name: str) -> Failure:
|
||||
self.assertIs(
|
||||
current_context(),
|
||||
context_one,
|
||||
@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase):
|
||||
self.clock.pump((1.0,))
|
||||
|
||||
self.assertTrue(
|
||||
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
|
||||
blocking_was_cancelled, "non-completing deferred was not cancelled"
|
||||
)
|
||||
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||
self.assertIs(current_context(), context_one)
|
||||
@ -220,13 +228,13 @@ class _TestException(Exception):
|
||||
|
||||
|
||||
class ConcurrentlyExecuteTest(TestCase):
|
||||
def test_limits_runners(self):
|
||||
def test_limits_runners(self) -> None:
|
||||
"""If we have more tasks than runners, we should get the limit of runners"""
|
||||
started = 0
|
||||
waiters = []
|
||||
processed = []
|
||||
|
||||
async def callback(v):
|
||||
async def callback(v: int) -> None:
|
||||
# when we first enter, bump the start count
|
||||
nonlocal started
|
||||
started += 1
|
||||
@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase):
|
||||
processed.append(v)
|
||||
|
||||
# wait for the goahead before returning
|
||||
d2 = Deferred()
|
||||
d2: "Deferred[int]" = Deferred()
|
||||
waiters.append(d2)
|
||||
await d2
|
||||
|
||||
@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase):
|
||||
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
|
||||
self.successResultOf(d2)
|
||||
|
||||
def test_preserves_stacktraces(self):
|
||||
def test_preserves_stacktraces(self) -> None:
|
||||
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
|
||||
d1 = Deferred()
|
||||
d1: "Deferred[int]" = Deferred()
|
||||
|
||||
async def callback(v):
|
||||
async def callback(v: int) -> None:
|
||||
# alas, this doesn't work at all without an await here
|
||||
await d1
|
||||
raise _TestException("bah")
|
||||
|
||||
async def caller():
|
||||
async def caller() -> None:
|
||||
try:
|
||||
await concurrently_execute(callback, [1], 2)
|
||||
except _TestException as e:
|
||||
@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase):
|
||||
d1.callback(0)
|
||||
self.successResultOf(d2)
|
||||
|
||||
def test_preserves_stacktraces_on_preformed_failure(self):
|
||||
def test_preserves_stacktraces_on_preformed_failure(self) -> None:
|
||||
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
|
||||
d1 = Deferred()
|
||||
d1: "Deferred[int]" = Deferred()
|
||||
f = Failure(_TestException("bah"))
|
||||
|
||||
async def callback(v):
|
||||
async def callback(v: int) -> None:
|
||||
# alas, this doesn't work at all without an await here
|
||||
await d1
|
||||
await defer.fail(f)
|
||||
|
||||
async def caller():
|
||||
async def caller() -> None:
|
||||
try:
|
||||
await concurrently_execute(callback, [1], 2)
|
||||
except _TestException as e:
|
||||
@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
|
||||
|
||||
def test_succeed(self):
|
||||
def test_succeed(self) -> None:
|
||||
"""Test that the new `Deferred` receives the result."""
|
||||
deferred: "Deferred[str]" = Deferred()
|
||||
wrapper_deferred = self.wrap_deferred(deferred)
|
||||
@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase):
|
||||
self.assertTrue(wrapper_deferred.called)
|
||||
self.assertEqual("success", self.successResultOf(wrapper_deferred))
|
||||
|
||||
def test_failure(self):
|
||||
def test_failure(self) -> None:
|
||||
"""Test that the new `Deferred` receives the `Failure`."""
|
||||
deferred: "Deferred[str]" = Deferred()
|
||||
wrapper_deferred = self.wrap_deferred(deferred)
|
||||
@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase):
|
||||
class StopCancellationTests(TestCase):
|
||||
"""Tests for the `stop_cancellation` function."""
|
||||
|
||||
def test_cancellation(self):
|
||||
def test_cancellation(self) -> None:
|
||||
"""Test that cancellation of the new `Deferred` leaves the original running."""
|
||||
deferred: "Deferred[str]" = Deferred()
|
||||
wrapper_deferred = stop_cancellation(deferred)
|
||||
@ -384,7 +392,7 @@ class StopCancellationTests(TestCase):
|
||||
class DelayCancellationTests(TestCase):
|
||||
"""Tests for the `delay_cancellation` function."""
|
||||
|
||||
def test_deferred_cancellation(self):
|
||||
def test_deferred_cancellation(self) -> None:
|
||||
"""Test that cancellation of the new `Deferred` waits for the original."""
|
||||
deferred: "Deferred[str]" = Deferred()
|
||||
wrapper_deferred = delay_cancellation(deferred)
|
||||
@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase):
|
||||
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
|
||||
self.failureResultOf(wrapper_deferred, CancelledError)
|
||||
|
||||
def test_coroutine_cancellation(self):
|
||||
def test_coroutine_cancellation(self) -> None:
|
||||
"""Test that cancellation of the new `Deferred` waits for the original."""
|
||||
blocking_deferred: "Deferred[None]" = Deferred()
|
||||
completion_deferred: "Deferred[None]" = Deferred()
|
||||
|
||||
async def task():
|
||||
async def task() -> NoReturn:
|
||||
await blocking_deferred
|
||||
completion_deferred.callback(None)
|
||||
# Raise an exception. Twisted should consume it, otherwise unwanted
|
||||
@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase):
|
||||
# Now that the original coroutine has failed, we should get a `CancelledError`.
|
||||
self.failureResultOf(wrapper_deferred, CancelledError)
|
||||
|
||||
def test_suppresses_second_cancellation(self):
|
||||
def test_suppresses_second_cancellation(self) -> None:
|
||||
"""Test that a second cancellation is suppressed.
|
||||
|
||||
Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
|
||||
@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase):
|
||||
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
|
||||
self.failureResultOf(wrapper_deferred, CancelledError)
|
||||
|
||||
def test_propagates_cancelled_error(self):
|
||||
def test_propagates_cancelled_error(self) -> None:
|
||||
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
|
||||
deferred: "Deferred[str]" = Deferred()
|
||||
wrapper_deferred = delay_cancellation(deferred)
|
||||
@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase):
|
||||
self.assertTrue(wrapper_deferred.called)
|
||||
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
|
||||
|
||||
def test_preserves_logcontext(self):
|
||||
def test_preserves_logcontext(self) -> None:
|
||||
"""Test that logging contexts are preserved."""
|
||||
blocking_d: "Deferred[None]" = Deferred()
|
||||
|
||||
async def inner():
|
||||
async def inner() -> None:
|
||||
await make_deferred_yieldable(blocking_d)
|
||||
|
||||
async def outer():
|
||||
async def outer() -> None:
|
||||
with LoggingContext("c") as c:
|
||||
try:
|
||||
await delay_cancellation(inner())
|
||||
@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase):
|
||||
class AwakenableSleeperTests(TestCase):
|
||||
"Tests AwakenableSleeper"
|
||||
|
||||
def test_sleep(self):
|
||||
def test_sleep(self) -> None:
|
||||
reactor, _ = get_clock()
|
||||
sleeper = AwakenableSleeper(reactor)
|
||||
|
||||
@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase):
|
||||
reactor.advance(0.6)
|
||||
self.assertTrue(d.called)
|
||||
|
||||
def test_explicit_wake(self):
|
||||
def test_explicit_wake(self) -> None:
|
||||
reactor, _ = get_clock()
|
||||
sleeper = AwakenableSleeper(reactor)
|
||||
|
||||
@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase):
|
||||
|
||||
reactor.advance(0.6)
|
||||
|
||||
def test_multiple_sleepers_timeout(self):
|
||||
def test_multiple_sleepers_timeout(self) -> None:
|
||||
reactor, _ = get_clock()
|
||||
sleeper = AwakenableSleeper(reactor)
|
||||
|
||||
@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase):
|
||||
reactor.advance(0.6)
|
||||
self.assertTrue(d2.called)
|
||||
|
||||
def test_multiple_sleepers_wake(self):
|
||||
def test_multiple_sleepers_wake(self) -> None:
|
||||
reactor, _ = get_clock()
|
||||
sleeper = AwakenableSleeper(reactor)
|
||||
|
||||
|
@ -11,6 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Tuple
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
@ -26,7 +30,7 @@ from tests.unittest import TestCase
|
||||
|
||||
|
||||
class BatchingQueueTestCase(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.clock, hs_clock = get_clock()
|
||||
|
||||
# We ensure that we remove any existing metrics for "test_queue".
|
||||
@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
self._pending_calls = []
|
||||
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
|
||||
self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
|
||||
self.queue: BatchingQueue[str, str] = BatchingQueue(
|
||||
"test_queue", hs_clock, self._process_queue
|
||||
)
|
||||
|
||||
async def _process_queue(self, values):
|
||||
d = defer.Deferred()
|
||||
async def _process_queue(self, values: List[str]) -> str:
|
||||
d: "defer.Deferred[str]" = defer.Deferred()
|
||||
self._pending_calls.append((values, d))
|
||||
return await make_deferred_yieldable(d)
|
||||
|
||||
def _get_sample_with_name(self, metric, name) -> int:
|
||||
def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
|
||||
"""For a prometheus metric get the value of the sample that has a
|
||||
matching "name" label.
|
||||
"""
|
||||
for sample in metric.collect()[0].samples:
|
||||
for sample in next(iter(metric.collect())).samples:
|
||||
if sample.labels.get("name") == name:
|
||||
return sample.value
|
||||
|
||||
self.fail("Found no matching sample")
|
||||
|
||||
def _assert_metrics(self, queued, keys, in_flight):
|
||||
def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
|
||||
"""Assert that the metrics are correct"""
|
||||
|
||||
sample = self._get_sample_with_name(number_queued, self.queue._name)
|
||||
@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase):
|
||||
"number_in_flight",
|
||||
)
|
||||
|
||||
def test_simple(self):
|
||||
def test_simple(self) -> None:
|
||||
"""Tests the basic case of calling `add_to_queue` once and having
|
||||
`_process_queue` return.
|
||||
"""
|
||||
@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase):
|
||||
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||
|
||||
def test_batching(self):
|
||||
def test_batching(self) -> None:
|
||||
"""Test that multiple calls at the same time get batched up into one
|
||||
call to `_process_queue`.
|
||||
"""
|
||||
@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase):
|
||||
self.assertEqual(self.successResultOf(queue_d2), "bar")
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||
|
||||
def test_queuing(self):
|
||||
def test_queuing(self) -> None:
|
||||
"""Test that we queue up requests while a `_process_queue` is being
|
||||
called.
|
||||
"""
|
||||
@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase):
|
||||
self.assertEqual(self.successResultOf(queue_d3), "bar2")
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||
|
||||
def test_different_keys(self):
|
||||
def test_different_keys(self) -> None:
|
||||
"""Test that calls to different keys get processed in parallel."""
|
||||
|
||||
self.assertFalse(self._pending_calls)
|
||||
|
@ -1,5 +1,20 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional
|
||||
from os import PathLike
|
||||
from typing import Generator, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
from synapse.util.check_dependencies import (
|
||||
@ -12,17 +27,17 @@ from tests.unittest import TestCase
|
||||
|
||||
|
||||
class DummyDistribution(metadata.Distribution):
|
||||
def __init__(self, version: object):
|
||||
def __init__(self, version: str):
|
||||
self._version = version
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
def version(self) -> str:
|
||||
return self._version
|
||||
|
||||
def locate_file(self, path):
|
||||
def locate_file(self, path: Union[str, PathLike]) -> PathLike:
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_text(self, filename):
|
||||
def read_text(self, filename: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2")
|
||||
old_release_candidate = DummyDistribution("0.1.2rc3")
|
||||
new = DummyDistribution("1.2.3")
|
||||
new_release_candidate = DummyDistribution("1.2.3rc4")
|
||||
distribution_with_no_version = DummyDistribution(None)
|
||||
distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type]
|
||||
|
||||
# could probably use stdlib TestCase --- no need for twisted here
|
||||
|
||||
@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase):
|
||||
If `distribution = None`, we pretend that the package is not installed.
|
||||
"""
|
||||
|
||||
def mock_distribution(name: str):
|
||||
def mock_distribution(name: str) -> DummyDistribution:
|
||||
if distribution is None:
|
||||
raise metadata.PackageNotFoundError
|
||||
else:
|
||||
|
@ -19,10 +19,12 @@ from tests import unittest
|
||||
|
||||
|
||||
class DictCacheTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.cache = DictionaryCache("foobar", max_entries=10)
|
||||
def setUp(self) -> None:
|
||||
self.cache: DictionaryCache[str, str, str] = DictionaryCache(
|
||||
"foobar", max_entries=10
|
||||
)
|
||||
|
||||
def test_simple_cache_hit_full(self):
|
||||
def test_simple_cache_hit_full(self) -> None:
|
||||
key = "test_simple_cache_hit_full"
|
||||
|
||||
v = self.cache.get(key)
|
||||
@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase):
|
||||
c = self.cache.get(key)
|
||||
self.assertEqual(test_value, c.value)
|
||||
|
||||
def test_simple_cache_hit_partial(self):
|
||||
def test_simple_cache_hit_partial(self) -> None:
|
||||
key = "test_simple_cache_hit_partial"
|
||||
|
||||
seq = self.cache.sequence
|
||||
@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase):
|
||||
c = self.cache.get(key, ["test"])
|
||||
self.assertEqual(test_value, c.value)
|
||||
|
||||
def test_simple_cache_miss_partial(self):
|
||||
def test_simple_cache_miss_partial(self) -> None:
|
||||
key = "test_simple_cache_miss_partial"
|
||||
|
||||
seq = self.cache.sequence
|
||||
@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase):
|
||||
c = self.cache.get(key, ["test2"])
|
||||
self.assertEqual({}, c.value)
|
||||
|
||||
def test_simple_cache_hit_miss_partial(self):
|
||||
def test_simple_cache_hit_miss_partial(self) -> None:
|
||||
key = "test_simple_cache_hit_miss_partial"
|
||||
|
||||
seq = self.cache.sequence
|
||||
@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase):
|
||||
c = self.cache.get(key, ["test2"])
|
||||
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
|
||||
|
||||
def test_multi_insert(self):
|
||||
def test_multi_insert(self) -> None:
|
||||
key = "test_simple_cache_hit_miss_partial"
|
||||
|
||||
seq = self.cache.sequence
|
||||
@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(c.full, False)
|
||||
|
||||
def test_invalidation(self):
|
||||
def test_invalidation(self) -> None:
|
||||
"""Test that the partial dict and full dicts get invalidated
|
||||
separately.
|
||||
"""
|
||||
@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase):
|
||||
# entry for "a" warm.
|
||||
for i in range(20):
|
||||
self.cache.get(key, ["a"])
|
||||
self.cache.update(seq, f"key{i}", {1: 2})
|
||||
self.cache.update(seq, f"key{i}", {"1": "2"})
|
||||
|
||||
# We should have evicted the full dict...
|
||||
r = self.cache.get(key)
|
||||
|
@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, cast
|
||||
|
||||
from synapse.util import Clock
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
from tests.utils import MockClock
|
||||
@ -21,17 +23,21 @@ from .. import unittest
|
||||
|
||||
|
||||
class ExpiringCacheTestCase(unittest.HomeserverTestCase):
|
||||
def test_get_set(self):
|
||||
def test_get_set(self) -> None:
|
||||
clock = MockClock()
|
||||
cache = ExpiringCache("test", clock, max_len=1)
|
||||
cache: ExpiringCache[str, str] = ExpiringCache(
|
||||
"test", cast(Clock, clock), max_len=1
|
||||
)
|
||||
|
||||
cache["key"] = "value"
|
||||
self.assertEqual(cache.get("key"), "value")
|
||||
self.assertEqual(cache["key"], "value")
|
||||
|
||||
def test_eviction(self):
|
||||
def test_eviction(self) -> None:
|
||||
clock = MockClock()
|
||||
cache = ExpiringCache("test", clock, max_len=2)
|
||||
cache: ExpiringCache[str, str] = ExpiringCache(
|
||||
"test", cast(Clock, clock), max_len=2
|
||||
)
|
||||
|
||||
cache["key"] = "value"
|
||||
cache["key2"] = "value2"
|
||||
@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(cache.get("key2"), "value2")
|
||||
self.assertEqual(cache.get("key3"), "value3")
|
||||
|
||||
def test_iterable_eviction(self):
|
||||
def test_iterable_eviction(self) -> None:
|
||||
clock = MockClock()
|
||||
cache = ExpiringCache("test", clock, max_len=5, iterable=True)
|
||||
cache: ExpiringCache[str, List[int]] = ExpiringCache(
|
||||
"test", cast(Clock, clock), max_len=5, iterable=True
|
||||
)
|
||||
|
||||
cache["key"] = [1]
|
||||
cache["key2"] = [2, 3]
|
||||
@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(cache.get("key3"), [4, 5])
|
||||
self.assertEqual(cache.get("key4"), [6, 7])
|
||||
|
||||
def test_time_eviction(self):
|
||||
def test_time_eviction(self) -> None:
|
||||
clock = MockClock()
|
||||
cache = ExpiringCache("test", clock, expiry_ms=1000)
|
||||
cache: ExpiringCache[str, int] = ExpiringCache(
|
||||
"test", cast(Clock, clock), expiry_ms=1000
|
||||
)
|
||||
|
||||
cache["key"] = 1
|
||||
clock.advance_time(0.5)
|
||||
|
@ -12,22 +12,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import threading
|
||||
from io import StringIO
|
||||
from io import BytesIO
|
||||
from typing import BinaryIO, Generator, Optional, cast
|
||||
from unittest.mock import NonCallableMock
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, reactor as _reactor
|
||||
from twisted.internet.interfaces import IPullProducer
|
||||
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||
|
||||
from tests import unittest
|
||||
|
||||
reactor = cast(ISynapseReactor, _reactor)
|
||||
|
||||
|
||||
class FileConsumerTests(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_pull_consumer(self):
|
||||
string_file = StringIO()
|
||||
def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||
string_file = BytesIO()
|
||||
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||
|
||||
try:
|
||||
@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase):
|
||||
|
||||
yield producer.register_with_consumer(consumer)
|
||||
|
||||
yield producer.write_and_wait("Foo")
|
||||
yield producer.write_and_wait(b"Foo")
|
||||
|
||||
self.assertEqual(string_file.getvalue(), "Foo")
|
||||
self.assertEqual(string_file.getvalue(), b"Foo")
|
||||
|
||||
yield producer.write_and_wait("Bar")
|
||||
yield producer.write_and_wait(b"Bar")
|
||||
|
||||
self.assertEqual(string_file.getvalue(), "FooBar")
|
||||
self.assertEqual(string_file.getvalue(), b"FooBar")
|
||||
finally:
|
||||
consumer.unregisterProducer()
|
||||
|
||||
yield consumer.wait()
|
||||
yield consumer.wait() # type: ignore[misc]
|
||||
|
||||
self.assertTrue(string_file.closed)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_push_consumer(self):
|
||||
string_file = BlockingStringWrite()
|
||||
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||
def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||
string_file = BlockingBytesWrite()
|
||||
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
|
||||
|
||||
try:
|
||||
producer = NonCallableMock(spec_set=[])
|
||||
|
||||
consumer.registerProducer(producer, True)
|
||||
|
||||
consumer.write("Foo")
|
||||
yield string_file.wait_for_n_writes(1)
|
||||
consumer.write(b"Foo")
|
||||
yield string_file.wait_for_n_writes(1) # type: ignore[misc]
|
||||
|
||||
self.assertEqual(string_file.buffer, "Foo")
|
||||
self.assertEqual(string_file.buffer, b"Foo")
|
||||
|
||||
consumer.write("Bar")
|
||||
yield string_file.wait_for_n_writes(2)
|
||||
consumer.write(b"Bar")
|
||||
yield string_file.wait_for_n_writes(2) # type: ignore[misc]
|
||||
|
||||
self.assertEqual(string_file.buffer, "FooBar")
|
||||
self.assertEqual(string_file.buffer, b"FooBar")
|
||||
finally:
|
||||
consumer.unregisterProducer()
|
||||
|
||||
yield consumer.wait()
|
||||
yield consumer.wait() # type: ignore[misc]
|
||||
|
||||
self.assertTrue(string_file.closed)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_push_producer_feedback(self):
|
||||
string_file = BlockingStringWrite()
|
||||
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||
def test_push_producer_feedback(
|
||||
self,
|
||||
) -> Generator["defer.Deferred[object]", object, None]:
|
||||
string_file = BlockingBytesWrite()
|
||||
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
|
||||
|
||||
try:
|
||||
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
|
||||
|
||||
resume_deferred = defer.Deferred()
|
||||
resume_deferred: defer.Deferred = defer.Deferred()
|
||||
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
|
||||
None
|
||||
)
|
||||
@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase):
|
||||
number_writes = 0
|
||||
with string_file.write_lock:
|
||||
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
|
||||
consumer.write("Foo")
|
||||
consumer.write(b"Foo")
|
||||
number_writes += 1
|
||||
|
||||
producer.pauseProducing.assert_called_once()
|
||||
|
||||
yield string_file.wait_for_n_writes(number_writes)
|
||||
yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc]
|
||||
|
||||
yield resume_deferred
|
||||
producer.resumeProducing.assert_called_once()
|
||||
finally:
|
||||
consumer.unregisterProducer()
|
||||
|
||||
yield consumer.wait()
|
||||
yield consumer.wait() # type: ignore[misc]
|
||||
|
||||
self.assertTrue(string_file.closed)
|
||||
|
||||
|
||||
@implementer(IPullProducer)
|
||||
class DummyPullProducer:
|
||||
def __init__(self):
|
||||
self.consumer = None
|
||||
self.deferred = defer.Deferred()
|
||||
def __init__(self) -> None:
|
||||
self.consumer: Optional[BackgroundFileConsumer] = None
|
||||
self.deferred: "defer.Deferred[object]" = defer.Deferred()
|
||||
|
||||
def resumeProducing(self):
|
||||
def resumeProducing(self) -> None:
|
||||
d = self.deferred
|
||||
self.deferred = defer.Deferred()
|
||||
d.callback(None)
|
||||
|
||||
def write_and_wait(self, bytes):
|
||||
def stopProducing(self) -> None:
|
||||
raise RuntimeError("Unexpected call")
|
||||
|
||||
def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
|
||||
assert self.consumer is not None
|
||||
d = self.deferred
|
||||
self.consumer.write(bytes)
|
||||
self.consumer.write(write_bytes)
|
||||
return d
|
||||
|
||||
def register_with_consumer(self, consumer):
|
||||
def register_with_consumer(
|
||||
self, consumer: BackgroundFileConsumer
|
||||
) -> "defer.Deferred[object]":
|
||||
d = self.deferred
|
||||
self.consumer = consumer
|
||||
self.consumer.registerProducer(self, False)
|
||||
return d
|
||||
|
||||
|
||||
class BlockingStringWrite:
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
class BlockingBytesWrite:
|
||||
def __init__(self) -> None:
|
||||
self.buffer = b""
|
||||
self.closed = False
|
||||
self.write_lock = threading.Lock()
|
||||
|
||||
self._notify_write_deferred = None
|
||||
self._notify_write_deferred: Optional[defer.Deferred] = None
|
||||
self._number_of_writes = 0
|
||||
|
||||
def write(self, bytes):
|
||||
def write(self, write_bytes: bytes) -> None:
|
||||
with self.write_lock:
|
||||
self.buffer += bytes
|
||||
self.buffer += write_bytes
|
||||
self._number_of_writes += 1
|
||||
|
||||
reactor.callFromThread(self._notify_write)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
def _notify_write(self):
|
||||
def _notify_write(self) -> None:
|
||||
"Called by write to indicate a write happened"
|
||||
with self.write_lock:
|
||||
if not self._notify_write_deferred:
|
||||
@ -161,7 +176,9 @@ class BlockingStringWrite:
|
||||
d.callback(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_n_writes(self, n):
|
||||
def wait_for_n_writes(
|
||||
self, n: int
|
||||
) -> Generator["defer.Deferred[object]", object, None]:
|
||||
"Wait for n writes to have happened"
|
||||
while True:
|
||||
with self.write_lock:
|
||||
|
@ -19,7 +19,7 @@ from tests.unittest import TestCase
|
||||
|
||||
|
||||
class ChunkSeqTests(TestCase):
|
||||
def test_short_seq(self):
|
||||
def test_short_seq(self) -> None:
|
||||
parts = chunk_seq("123", 8)
|
||||
|
||||
self.assertEqual(
|
||||
@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
|
||||
["123"],
|
||||
)
|
||||
|
||||
def test_long_seq(self):
|
||||
def test_long_seq(self) -> None:
|
||||
parts = chunk_seq("abcdefghijklmnop", 8)
|
||||
|
||||
self.assertEqual(
|
||||
@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
|
||||
["abcdefgh", "ijklmnop"],
|
||||
)
|
||||
|
||||
def test_uneven_parts(self):
|
||||
def test_uneven_parts(self) -> None:
|
||||
parts = chunk_seq("abcdefghijklmnop", 5)
|
||||
|
||||
self.assertEqual(
|
||||
@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
|
||||
["abcde", "fghij", "klmno", "p"],
|
||||
)
|
||||
|
||||
def test_empty_input(self):
|
||||
def test_empty_input(self) -> None:
|
||||
parts: Iterable[Sequence] = chunk_seq([], 5)
|
||||
|
||||
self.assertEqual(
|
||||
@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):
|
||||
|
||||
|
||||
class SortTopologically(TestCase):
|
||||
def test_empty(self):
|
||||
def test_empty(self) -> None:
|
||||
"Test that an empty graph works correctly"
|
||||
|
||||
graph: Dict[int, List[int]] = {}
|
||||
self.assertEqual(list(sorted_topologically([], graph)), [])
|
||||
|
||||
def test_handle_empty_graph(self):
|
||||
def test_handle_empty_graph(self) -> None:
|
||||
"Test that a graph where a node doesn't have an entry is treated as empty"
|
||||
|
||||
graph: Dict[int, List[int]] = {}
|
||||
@ -67,7 +67,7 @@ class SortTopologically(TestCase):
|
||||
# For disconnected nodes the output is simply sorted.
|
||||
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
|
||||
|
||||
def test_disconnected(self):
|
||||
def test_disconnected(self) -> None:
|
||||
"Test that a graph with no edges work"
|
||||
|
||||
graph: Dict[int, List[int]] = {1: [], 2: []}
|
||||
@ -75,20 +75,20 @@ class SortTopologically(TestCase):
|
||||
# For disconnected nodes the output is simply sorted.
|
||||
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
|
||||
|
||||
def test_linear(self):
|
||||
def test_linear(self) -> None:
|
||||
"Test that a simple `4 -> 3 -> 2 -> 1` graph works"
|
||||
|
||||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
|
||||
|
||||
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
|
||||
|
||||
def test_subset(self):
|
||||
def test_subset(self) -> None:
|
||||
"Test that only sorting a subset of the graph works"
|
||||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
|
||||
|
||||
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
|
||||
|
||||
def test_fork(self):
|
||||
def test_fork(self) -> None:
|
||||
"Test that a forked graph works"
|
||||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
|
||||
|
||||
@ -96,13 +96,13 @@ class SortTopologically(TestCase):
|
||||
# always get the same one.
|
||||
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
|
||||
|
||||
def test_duplicates(self):
|
||||
def test_duplicates(self) -> None:
|
||||
"Test that a graph with duplicate edges work"
|
||||
graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
|
||||
|
||||
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
|
||||
|
||||
def test_multiple_paths(self):
|
||||
def test_multiple_paths(self) -> None:
|
||||
"Test that a graph with multiple paths between two nodes work"
|
||||
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
|
||||
|
||||
|
@ -1,5 +1,21 @@
|
||||
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Generator, cast
|
||||
|
||||
import twisted.python.failure
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer, reactor as _reactor
|
||||
|
||||
from synapse.logging.context import (
|
||||
SENTINEL_CONTEXT,
|
||||
@ -10,25 +26,30 @@ from synapse.logging.context import (
|
||||
nested_logging_context,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.util import Clock
|
||||
|
||||
from .. import unittest
|
||||
|
||||
reactor = cast(ISynapseReactor, _reactor)
|
||||
|
||||
|
||||
class LoggingContextTestCase(unittest.TestCase):
|
||||
def _check_test_key(self, value):
|
||||
self.assertEqual(current_context().name, value)
|
||||
def _check_test_key(self, value: str) -> None:
|
||||
context = current_context()
|
||||
assert isinstance(context, LoggingContext)
|
||||
self.assertEqual(context.name, value)
|
||||
|
||||
def test_with_context(self):
|
||||
def test_with_context(self) -> None:
|
||||
with LoggingContext("test"):
|
||||
self._check_test_key("test")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_sleep(self):
|
||||
def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||
clock = Clock(reactor)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def competing_callback():
|
||||
def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
|
||||
with LoggingContext("competing"):
|
||||
yield clock.sleep(0)
|
||||
self._check_test_key("competing")
|
||||
@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
yield clock.sleep(0)
|
||||
self._check_test_key("one")
|
||||
|
||||
def _test_run_in_background(self, function):
|
||||
def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
|
||||
sentinel_context = current_context()
|
||||
|
||||
callback_completed = [False]
|
||||
callback_completed = False
|
||||
|
||||
with LoggingContext("one"):
|
||||
# fire off function, but don't wait on it.
|
||||
d2 = run_in_background(function)
|
||||
|
||||
def cb(res):
|
||||
callback_completed[0] = True
|
||||
def cb(res: object) -> object:
|
||||
nonlocal callback_completed
|
||||
callback_completed = True
|
||||
return res
|
||||
|
||||
d2.addCallback(cb)
|
||||
@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
# the logcontext is left in a sane state.
|
||||
d2 = defer.Deferred()
|
||||
|
||||
def check_logcontext():
|
||||
if not callback_completed[0]:
|
||||
def check_logcontext() -> None:
|
||||
if not callback_completed:
|
||||
reactor.callLater(0.01, check_logcontext)
|
||||
return
|
||||
|
||||
@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
# test is done once d2 finishes
|
||||
return d2
|
||||
|
||||
def test_run_in_background_with_blocking_fn(self):
|
||||
def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
|
||||
@defer.inlineCallbacks
|
||||
def blocking_function():
|
||||
def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
|
||||
yield Clock(reactor).sleep(0)
|
||||
|
||||
return self._test_run_in_background(blocking_function)
|
||||
|
||||
def test_run_in_background_with_non_blocking_fn(self):
|
||||
def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
|
||||
@defer.inlineCallbacks
|
||||
def nonblocking_function():
|
||||
def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
|
||||
with PreserveLoggingContext():
|
||||
yield defer.succeed(None)
|
||||
|
||||
return self._test_run_in_background(nonblocking_function)
|
||||
|
||||
def test_run_in_background_with_chained_deferred(self):
|
||||
def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
|
||||
# a function which returns a deferred which looks like it has been
|
||||
# called, but is actually paused
|
||||
def testfunc():
|
||||
def testfunc() -> defer.Deferred:
|
||||
return make_deferred_yieldable(_chained_deferred_function())
|
||||
|
||||
return self._test_run_in_background(testfunc)
|
||||
|
||||
def test_run_in_background_with_coroutine(self):
|
||||
async def testfunc():
|
||||
def test_run_in_background_with_coroutine(self) -> defer.Deferred:
|
||||
async def testfunc() -> None:
|
||||
self._check_test_key("one")
|
||||
d = Clock(reactor).sleep(0)
|
||||
self.assertIs(current_context(), SENTINEL_CONTEXT)
|
||||
@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
|
||||
return self._test_run_in_background(testfunc)
|
||||
|
||||
def test_run_in_background_with_nonblocking_coroutine(self):
|
||||
async def testfunc():
|
||||
def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
|
||||
async def testfunc() -> None:
|
||||
self._check_test_key("one")
|
||||
|
||||
return self._test_run_in_background(testfunc)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_make_deferred_yieldable(self):
|
||||
def test_make_deferred_yieldable(
|
||||
self,
|
||||
) -> Generator["defer.Deferred[object]", object, None]:
|
||||
# a function which returns an incomplete deferred, but doesn't follow
|
||||
# the synapse rules.
|
||||
def blocking_function():
|
||||
d = defer.Deferred()
|
||||
def blocking_function() -> defer.Deferred:
|
||||
d: defer.Deferred = defer.Deferred()
|
||||
reactor.callLater(0, d.callback, None)
|
||||
return d
|
||||
|
||||
@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
self._check_test_key("one")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_make_deferred_yieldable_with_chained_deferreds(self):
|
||||
def test_make_deferred_yieldable_with_chained_deferreds(
|
||||
self,
|
||||
) -> Generator["defer.Deferred[object]", object, None]:
|
||||
sentinel_context = current_context()
|
||||
|
||||
with LoggingContext("one"):
|
||||
@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
# now it should be restored
|
||||
self._check_test_key("one")
|
||||
|
||||
def test_nested_logging_context(self):
|
||||
def test_nested_logging_context(self) -> None:
|
||||
with LoggingContext("foo"):
|
||||
nested_context = nested_logging_context(suffix="bar")
|
||||
self.assertEqual(nested_context.name, "foo-bar")
|
||||
@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
# a function which returns a deferred which has been "called", but
|
||||
# which had a function which returned another incomplete deferred on
|
||||
# its callback list, so won't yet call any other new callbacks.
|
||||
def _chained_deferred_function():
|
||||
def _chained_deferred_function() -> defer.Deferred:
|
||||
d = defer.succeed(None)
|
||||
|
||||
def cb(res):
|
||||
d2 = defer.Deferred()
|
||||
def cb(res: object) -> defer.Deferred:
|
||||
d2: defer.Deferred = defer.Deferred()
|
||||
reactor.callLater(0, d2.callback, res)
|
||||
return d2
|
||||
|
||||
|
@ -23,7 +23,7 @@ class TestException(Exception):
|
||||
|
||||
|
||||
class LogFormatterTestCase(unittest.TestCase):
|
||||
def test_formatter(self):
|
||||
def test_formatter(self) -> None:
|
||||
formatter = LogFormatter()
|
||||
|
||||
try:
|
||||
|
@ -13,10 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from synapse.metrics.jemalloc import JemallocStats
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
|
||||
from synapse.util.caches.treecache import TreeCache
|
||||
|
||||
@ -25,14 +26,14 @@ from tests.unittest import override_config
|
||||
|
||||
|
||||
class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||
def test_get_set(self):
|
||||
cache = LruCache(1)
|
||||
def test_get_set(self) -> None:
|
||||
cache: LruCache[str, str] = LruCache(1)
|
||||
cache["key"] = "value"
|
||||
self.assertEqual(cache.get("key"), "value")
|
||||
self.assertEqual(cache["key"], "value")
|
||||
|
||||
def test_eviction(self):
|
||||
cache = LruCache(2)
|
||||
def test_eviction(self) -> None:
|
||||
cache: LruCache[int, int] = LruCache(2)
|
||||
cache[1] = 1
|
||||
cache[2] = 2
|
||||
|
||||
@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(cache.get(2), 2)
|
||||
self.assertEqual(cache.get(3), 3)
|
||||
|
||||
def test_setdefault(self):
|
||||
cache = LruCache(1)
|
||||
def test_setdefault(self) -> None:
|
||||
cache: LruCache[str, int] = LruCache(1)
|
||||
self.assertEqual(cache.setdefault("key", 1), 1)
|
||||
self.assertEqual(cache.get("key"), 1)
|
||||
self.assertEqual(cache.setdefault("key", 2), 1)
|
||||
@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||
cache["key"] = 2 # Make sure overriding works.
|
||||
self.assertEqual(cache.get("key"), 2)
|
||||
|
||||
def test_pop(self):
|
||||
cache = LruCache(1)
|
||||
def test_pop(self) -> None:
|
||||
cache: LruCache[str, int] = LruCache(1)
|
||||
cache["key"] = 1
|
||||
self.assertEqual(cache.pop("key"), 1)
|
||||
self.assertEqual(cache.pop("key"), None)
|
||||
|
||||
def test_del_multi(self):
|
||||
cache = LruCache(4, cache_type=TreeCache)
|
||||
def test_del_multi(self) -> None:
|
||||
# The type here isn't quite correct as they don't handle TreeCache well.
|
||||
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
|
||||
cache[("animal", "cat")] = "mew"
|
||||
cache[("animal", "dog")] = "woof"
|
||||
cache[("vehicles", "car")] = "vroom"
|
||||
@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(cache.get(("animal", "cat")), "mew")
|
||||
self.assertEqual(cache.get(("vehicles", "car")), "vroom")
|
||||
cache.del_multi(("animal",))
|
||||
cache.del_multi(("animal",)) # type: ignore[arg-type]
|
||||
self.assertEqual(len(cache), 2)
|
||||
self.assertEqual(cache.get(("animal", "cat")), None)
|
||||
self.assertEqual(cache.get(("animal", "dog")), None)
|
||||
@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(cache.get(("vehicles", "train")), "chuff")
|
||||
# Man from del_multi say "Yes".
|
||||
|
||||
def test_clear(self):
|
||||
cache = LruCache(1)
|
||||
def test_clear(self) -> None:
|
||||
cache: LruCache[str, int] = LruCache(1)
|
||||
cache["key"] = 1
|
||||
cache.clear()
|
||||
self.assertEqual(len(cache), 0)
|
||||
|
||||
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
|
||||
def test_special_size(self):
|
||||
cache = LruCache(10, "mycache")
|
||||
def test_special_size(self) -> None:
|
||||
cache: LruCache = LruCache(10, "mycache")
|
||||
self.assertEqual(cache.max_size, 100)
|
||||
|
||||
|
||||
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
def test_get(self):
|
||||
def test_get(self) -> None:
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
cache: LruCache[str, str] = LruCache(1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertFalse(m.called)
|
||||
@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
cache.set("key", "value")
|
||||
self.assertEqual(m.call_count, 1)
|
||||
|
||||
def test_multi_get(self):
|
||||
def test_multi_get(self) -> None:
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
cache: LruCache[str, str] = LruCache(1)
|
||||
|
||||
cache.set("key", "value")
|
||||
self.assertFalse(m.called)
|
||||
@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
cache.set("key", "value")
|
||||
self.assertEqual(m.call_count, 1)
|
||||
|
||||
def test_set(self):
|
||||
def test_set(self) -> None:
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
cache: LruCache[str, str] = LruCache(1)
|
||||
|
||||
cache.set("key", "value", callbacks=[m])
|
||||
self.assertFalse(m.called)
|
||||
@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
cache.set("key", "value")
|
||||
self.assertEqual(m.call_count, 1)
|
||||
|
||||
def test_pop(self):
|
||||
def test_pop(self) -> None:
|
||||
m = Mock()
|
||||
cache = LruCache(1)
|
||||
cache: LruCache[str, str] = LruCache(1)
|
||||
|
||||
cache.set("key", "value", callbacks=[m])
|
||||
self.assertFalse(m.called)
|
||||
@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
cache.pop("key")
|
||||
self.assertEqual(m.call_count, 1)
|
||||
|
||||
def test_del_multi(self):
|
||||
def test_del_multi(self) -> None:
|
||||
m1 = Mock()
|
||||
m2 = Mock()
|
||||
m3 = Mock()
|
||||
m4 = Mock()
|
||||
cache = LruCache(4, cache_type=TreeCache)
|
||||
# The type here isn't quite correct as they don't handle TreeCache well.
|
||||
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
|
||||
|
||||
cache.set(("a", "1"), "value", callbacks=[m1])
|
||||
cache.set(("a", "2"), "value", callbacks=[m2])
|
||||
@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(m3.call_count, 0)
|
||||
self.assertEqual(m4.call_count, 0)
|
||||
|
||||
cache.del_multi(("a",))
|
||||
cache.del_multi(("a",)) # type: ignore[arg-type]
|
||||
|
||||
self.assertEqual(m1.call_count, 1)
|
||||
self.assertEqual(m2.call_count, 1)
|
||||
self.assertEqual(m3.call_count, 0)
|
||||
self.assertEqual(m4.call_count, 0)
|
||||
|
||||
def test_clear(self):
|
||||
def test_clear(self) -> None:
|
||||
m1 = Mock()
|
||||
m2 = Mock()
|
||||
cache = LruCache(5)
|
||||
cache: LruCache[str, str] = LruCache(5)
|
||||
|
||||
cache.set("key1", "value", callbacks=[m1])
|
||||
cache.set("key2", "value", callbacks=[m2])
|
||||
@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(m1.call_count, 1)
|
||||
self.assertEqual(m2.call_count, 1)
|
||||
|
||||
def test_eviction(self):
|
||||
def test_eviction(self) -> None:
|
||||
m1 = Mock(name="m1")
|
||||
m2 = Mock(name="m2")
|
||||
m3 = Mock(name="m3")
|
||||
cache = LruCache(2)
|
||||
cache: LruCache[str, str] = LruCache(2)
|
||||
|
||||
cache.set("key1", "value", callbacks=[m1])
|
||||
cache.set("key2", "value", callbacks=[m2])
|
||||
@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
||||
class LruCacheSizedTestCase(unittest.HomeserverTestCase):
|
||||
def test_evict(self):
|
||||
cache = LruCache(5, size_callback=len)
|
||||
def test_evict(self) -> None:
|
||||
cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
|
||||
cache["key1"] = [0]
|
||||
cache["key2"] = [1, 2]
|
||||
cache["key3"] = [3]
|
||||
@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
|
||||
cache["key1"] = []
|
||||
|
||||
self.assertEqual(len(cache), 0)
|
||||
assert isinstance(cache.cache, dict)
|
||||
cache.cache["key1"].drop_from_cache()
|
||||
self.assertIsNone(
|
||||
cache.pop("key1"), "Cache entry should have been evicted but wasn't"
|
||||
@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
|
||||
class TimeEvictionTestCase(unittest.HomeserverTestCase):
|
||||
"""Test that time based eviction works correctly."""
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
|
||||
config.setdefault("caches", {})["expiry_time"] = "30m"
|
||||
|
||||
return config
|
||||
|
||||
def test_evict(self):
|
||||
def test_evict(self) -> None:
|
||||
setup_expire_lru_cache_entries(self.hs)
|
||||
|
||||
cache = LruCache(5, clock=self.hs.get_clock())
|
||||
cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
|
||||
|
||||
# Check that we evict entries we haven't accessed for 30 minutes.
|
||||
cache["key1"] = 1
|
||||
@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
)
|
||||
@patch("synapse.util.caches.lrucache.get_jemalloc_stats")
|
||||
def test_evict_memory(self, jemalloc_interface) -> None:
|
||||
def test_evict_memory(self, jemalloc_interface: Mock) -> None:
|
||||
mock_jemalloc_class = Mock(spec=JemallocStats)
|
||||
jemalloc_interface.return_value = mock_jemalloc_class
|
||||
|
||||
@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
|
||||
mock_jemalloc_class.get_stat.return_value = 924288000
|
||||
|
||||
setup_expire_lru_cache_entries(self.hs)
|
||||
cache = LruCache(4, clock=self.hs.get_clock())
|
||||
cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
|
||||
|
||||
cache["key1"] = 1
|
||||
cache["key2"] = 2
|
||||
|
@ -21,14 +21,14 @@ from tests.unittest import TestCase
|
||||
|
||||
|
||||
class MacaroonGeneratorTestCase(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.reactor, hs_clock = get_clock()
|
||||
self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
|
||||
self.other_macaroon_generator = MacaroonGenerator(
|
||||
hs_clock, "tesths", b"anothersecretkey"
|
||||
)
|
||||
|
||||
def test_guest_access_token(self):
|
||||
def test_guest_access_token(self) -> None:
|
||||
"""Test the generation and verification of guest access tokens"""
|
||||
token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
|
||||
user_id = self.macaroon_generator.verify_guest_token(token)
|
||||
@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase):
|
||||
with self.assertRaises(MacaroonVerificationFailedException):
|
||||
self.macaroon_generator.verify_guest_token(token)
|
||||
|
||||
def test_delete_pusher_token(self):
|
||||
def test_delete_pusher_token(self) -> None:
|
||||
"""Test the generation and verification of delete_pusher tokens"""
|
||||
token = self.macaroon_generator.generate_delete_pusher_token(
|
||||
"@user:tesths", "m.mail", "john@example.com"
|
||||
@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase):
|
||||
)
|
||||
self.assertEqual(user_id, "@user:tesths")
|
||||
|
||||
def test_oidc_session_token(self):
|
||||
def test_oidc_session_token(self) -> None:
|
||||
"""Test the generation and verification of OIDC session cookies"""
|
||||
state = "arandomstate"
|
||||
session_data = OidcSessionData(
|
||||
|
@ -13,16 +13,19 @@
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.ratelimiting import FederationRatelimitSettings
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
from tests.server import get_clock
|
||||
from tests.server import ThreadedMemoryReactorClock, get_clock
|
||||
from tests.unittest import TestCase
|
||||
from tests.utils import default_config
|
||||
|
||||
|
||||
class FederationRateLimiterTestCase(TestCase):
|
||||
def test_ratelimit(self):
|
||||
def test_ratelimit(self) -> None:
|
||||
"""A simple test with the default values"""
|
||||
reactor, clock = get_clock()
|
||||
rc_config = build_rc_config()
|
||||
@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase):
|
||||
# shouldn't block
|
||||
self.successResultOf(d1)
|
||||
|
||||
def test_concurrent_limit(self):
|
||||
def test_concurrent_limit(self) -> None:
|
||||
"""Test what happens when we hit the concurrent limit"""
|
||||
reactor, clock = get_clock()
|
||||
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
|
||||
@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase):
|
||||
cm2.__exit__(None, None, None)
|
||||
self.successResultOf(d3)
|
||||
|
||||
def test_sleep_limit(self):
|
||||
def test_sleep_limit(self) -> None:
|
||||
"""Test what happens when we hit the sleep limit"""
|
||||
reactor, clock = get_clock()
|
||||
rc_config = build_rc_config(
|
||||
@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase):
|
||||
self.assertAlmostEqual(sleep_time, 500, places=3)
|
||||
|
||||
|
||||
def _await_resolution(reactor, d):
|
||||
def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
|
||||
"""advance the clock until the deferred completes.
|
||||
|
||||
Returns the number of milliseconds it took to complete.
|
||||
@ -90,7 +93,7 @@ def _await_resolution(reactor, d):
|
||||
return (reactor.seconds() - start_time) * 1000
|
||||
|
||||
|
||||
def build_rc_config(settings: Optional[dict] = None):
|
||||
def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings:
|
||||
config_dict = default_config("test")
|
||||
config_dict.update(settings or {})
|
||||
config = HomeServerConfig()
|
||||
|
@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class RetryLimiterTestCase(HomeserverTestCase):
|
||||
def test_new_destination(self):
|
||||
def test_new_destination(self) -> None:
|
||||
"""A happy-path case with a new destination and a successful operation"""
|
||||
store = self.hs.get_datastores().main
|
||||
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
|
||||
@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
|
||||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
|
||||
self.assertIsNone(new_timings)
|
||||
|
||||
def test_limiter(self):
|
||||
def test_limiter(self) -> None:
|
||||
"""General test case which walks through the process of a failing request"""
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
|
@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||
acquired_d: "Deferred[None]" = Deferred()
|
||||
unblock_d: "Deferred[None]" = Deferred()
|
||||
|
||||
async def reader_or_writer():
|
||||
async def reader_or_writer() -> str:
|
||||
async with read_or_write(key):
|
||||
acquired_d.callback(None)
|
||||
await unblock_d
|
||||
@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||
d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
|
||||
)
|
||||
|
||||
def test_rwlock(self):
|
||||
def test_rwlock(self) -> None:
|
||||
rwlock = ReadWriteLock()
|
||||
key = "key"
|
||||
|
||||
@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||
_, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
|
||||
self.assertTrue(acquired_d.called)
|
||||
|
||||
def test_lock_handoff_to_nonblocking_writer(self):
|
||||
def test_lock_handoff_to_nonblocking_writer(self) -> None:
|
||||
"""Test a writer handing the lock to another writer that completes instantly."""
|
||||
rwlock = ReadWriteLock()
|
||||
key = "key"
|
||||
@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||
d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
|
||||
self.assertTrue(d3.called)
|
||||
|
||||
def test_cancellation_while_holding_read_lock(self):
|
||||
def test_cancellation_while_holding_read_lock(self) -> None:
|
||||
"""Test cancellation while holding a read lock.
|
||||
|
||||
A waiting writer should be given the lock when the reader holding the lock is
|
||||
@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual("write completed", self.successResultOf(writer_d))
|
||||
|
||||
def test_cancellation_while_holding_write_lock(self):
|
||||
def test_cancellation_while_holding_write_lock(self) -> None:
|
||||
"""Test cancellation while holding a write lock.
|
||||
|
||||
A waiting reader should be given the lock when the writer holding the lock is
|
||||
@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual("read completed", self.successResultOf(reader_d))
|
||||
|
||||
def test_cancellation_while_waiting_for_read_lock(self):
|
||||
def test_cancellation_while_waiting_for_read_lock(self) -> None:
|
||||
"""Test cancellation while waiting for a read lock.
|
||||
|
||||
Tests that cancelling a waiting reader:
|
||||
@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
|
||||
|
||||
def test_cancellation_while_waiting_for_write_lock(self):
|
||||
def test_cancellation_while_waiting_for_write_lock(self) -> None:
|
||||
"""Test cancellation while waiting for a write lock.
|
||||
|
||||
Tests that cancelling a waiting writer:
|
||||
|
@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
||||
Tests for StreamChangeCache.
|
||||
"""
|
||||
|
||||
def test_prefilled_cache(self):
|
||||
def test_prefilled_cache(self) -> None:
|
||||
"""
|
||||
Providing a prefilled cache to StreamChangeCache will result in a cache
|
||||
with the prefilled-cache entered in.
|
||||
@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
||||
cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
|
||||
self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
|
||||
|
||||
def test_has_entity_changed(self):
|
||||
def test_has_entity_changed(self) -> None:
|
||||
"""
|
||||
StreamChangeCache.entity_has_changed will mark entities as changed, and
|
||||
has_entity_changed will observe the changed entities.
|
||||
@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
||||
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
|
||||
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
|
||||
|
||||
def test_entity_has_changed_pops_off_start(self):
|
||||
def test_entity_has_changed_pops_off_start(self) -> None:
|
||||
"""
|
||||
StreamChangeCache.entity_has_changed will respect the max size and
|
||||
purge the oldest items upon reaching that max size.
|
||||
@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertIsNone(cache.get_all_entities_changed(1))
|
||||
|
||||
def test_get_all_entities_changed(self):
|
||||
def test_get_all_entities_changed(self) -> None:
|
||||
"""
|
||||
StreamChangeCache.get_all_entities_changed will return all changed
|
||||
entities since the given position. If the position is before the start
|
||||
@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
||||
r = cache.get_all_entities_changed(3)
|
||||
self.assertTrue(r == ok1 or r == ok2)
|
||||
|
||||
def test_has_any_entity_changed(self):
|
||||
def test_has_any_entity_changed(self) -> None:
|
||||
"""
|
||||
StreamChangeCache.has_any_entity_changed will return True if any
|
||||
entities have been changed since the provided stream position, and
|
||||
@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
||||
self.assertFalse(cache.has_any_entity_changed(2))
|
||||
self.assertFalse(cache.has_any_entity_changed(3))
|
||||
|
||||
def test_get_entities_changed(self):
|
||||
def test_get_entities_changed(self) -> None:
|
||||
"""
|
||||
StreamChangeCache.get_entities_changed will return the entities in the
|
||||
given list that have changed since the provided stream ID. If the
|
||||
@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
|
||||
{"bar@baz.net"},
|
||||
)
|
||||
|
||||
def test_max_pos(self):
|
||||
def test_max_pos(self) -> None:
|
||||
"""
|
||||
StreamChangeCache.get_max_pos_of_last_change will return the most
|
||||
recent point where the entity could have changed. If the entity is not
|
||||
|
@ -19,7 +19,7 @@ from .. import unittest
|
||||
|
||||
|
||||
class StringUtilsTestCase(unittest.TestCase):
|
||||
def test_client_secret_regex(self):
|
||||
def test_client_secret_regex(self) -> None:
|
||||
"""Ensure that client_secret does not contain illegal characters"""
|
||||
good = [
|
||||
"abcde12345",
|
||||
@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
|
||||
with self.assertRaises(SynapseError):
|
||||
assert_valid_client_secret(client_secret)
|
||||
|
||||
def test_base62_encode(self):
|
||||
def test_base62_encode(self) -> None:
|
||||
self.assertEqual("0", base62_encode(0))
|
||||
self.assertEqual("10", base62_encode(62))
|
||||
self.assertEqual("1c", base62_encode(100))
|
||||
|
@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class CanonicaliseEmailTests(HomeserverTestCase):
|
||||
def test_no_at(self):
|
||||
def test_no_at(self) -> None:
|
||||
with self.assertRaises(ValueError):
|
||||
canonicalise_email("address-without-at.bar")
|
||||
|
||||
def test_two_at(self):
|
||||
def test_two_at(self) -> None:
|
||||
with self.assertRaises(ValueError):
|
||||
canonicalise_email("foo@foo@test.bar")
|
||||
|
||||
def test_bad_format(self):
|
||||
def test_bad_format(self) -> None:
|
||||
with self.assertRaises(ValueError):
|
||||
canonicalise_email("user@bad.example.net@good.example.com")
|
||||
|
||||
def test_valid_format(self):
|
||||
def test_valid_format(self) -> None:
|
||||
self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
|
||||
|
||||
def test_domain_to_lower(self):
|
||||
def test_domain_to_lower(self) -> None:
|
||||
self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
|
||||
|
||||
def test_domain_with_umlaut(self):
|
||||
def test_domain_with_umlaut(self) -> None:
|
||||
self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
|
||||
|
||||
def test_address_casefold(self):
|
||||
def test_address_casefold(self) -> None:
|
||||
self.assertEqual(
|
||||
canonicalise_email("Strauß@Example.com"), "strauss@example.com"
|
||||
)
|
||||
|
||||
def test_address_trim(self):
|
||||
def test_address_trim(self) -> None:
|
||||
self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
|
||||
|
@ -19,7 +19,7 @@ from .. import unittest
|
||||
|
||||
|
||||
class TreeCacheTestCase(unittest.TestCase):
|
||||
def test_get_set_onelevel(self):
|
||||
def test_get_set_onelevel(self) -> None:
|
||||
cache = TreeCache()
|
||||
cache[("a",)] = "A"
|
||||
cache[("b",)] = "B"
|
||||
@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||
self.assertEqual(cache.get(("b",)), "B")
|
||||
self.assertEqual(len(cache), 2)
|
||||
|
||||
def test_pop_onelevel(self):
|
||||
def test_pop_onelevel(self) -> None:
|
||||
cache = TreeCache()
|
||||
cache[("a",)] = "A"
|
||||
cache[("b",)] = "B"
|
||||
@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||
self.assertEqual(cache.get(("b",)), "B")
|
||||
self.assertEqual(len(cache), 1)
|
||||
|
||||
def test_get_set_twolevel(self):
|
||||
def test_get_set_twolevel(self) -> None:
|
||||
cache = TreeCache()
|
||||
cache[("a", "a")] = "AA"
|
||||
cache[("a", "b")] = "AB"
|
||||
@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||
self.assertEqual(cache.get(("b", "a")), "BA")
|
||||
self.assertEqual(len(cache), 3)
|
||||
|
||||
def test_pop_twolevel(self):
|
||||
def test_pop_twolevel(self) -> None:
|
||||
cache = TreeCache()
|
||||
cache[("a", "a")] = "AA"
|
||||
cache[("a", "b")] = "AB"
|
||||
@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||
self.assertEqual(cache.pop(("b", "a")), None)
|
||||
self.assertEqual(len(cache), 1)
|
||||
|
||||
def test_pop_mixedlevel(self):
|
||||
def test_pop_mixedlevel(self) -> None:
|
||||
cache = TreeCache()
|
||||
cache[("a", "a")] = "AA"
|
||||
cache[("a", "b")] = "AB"
|
||||
@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
|
||||
|
||||
def test_clear(self):
|
||||
def test_clear(self) -> None:
|
||||
cache = TreeCache()
|
||||
cache[("a",)] = "A"
|
||||
cache[("b",)] = "B"
|
||||
cache.clear()
|
||||
self.assertEqual(len(cache), 0)
|
||||
|
||||
def test_contains(self):
|
||||
def test_contains(self) -> None:
|
||||
cache = TreeCache()
|
||||
cache[("a",)] = "A"
|
||||
self.assertTrue(("a",) in cache)
|
||||
|
@ -18,8 +18,8 @@ from .. import unittest
|
||||
|
||||
|
||||
class WheelTimerTestCase(unittest.TestCase):
|
||||
def test_single_insert_fetch(self):
|
||||
wheel = WheelTimer(bucket_size=5)
|
||||
def test_single_insert_fetch(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
|
||||
obj = object()
|
||||
wheel.insert(100, obj, 150)
|
||||
@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
|
||||
self.assertListEqual(wheel.fetch(156), [obj])
|
||||
self.assertListEqual(wheel.fetch(170), [])
|
||||
|
||||
def test_multi_insert(self):
|
||||
wheel = WheelTimer(bucket_size=5)
|
||||
def test_multi_insert(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
|
||||
obj1 = object()
|
||||
obj2 = object()
|
||||
@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
|
||||
self.assertListEqual(wheel.fetch(200), [obj3])
|
||||
self.assertListEqual(wheel.fetch(210), [])
|
||||
|
||||
def test_insert_past(self):
|
||||
wheel = WheelTimer(bucket_size=5)
|
||||
def test_insert_past(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
|
||||
obj = object()
|
||||
wheel.insert(100, obj, 50)
|
||||
self.assertListEqual(wheel.fetch(120), [obj])
|
||||
|
||||
def test_insert_past_multi(self):
|
||||
wheel = WheelTimer(bucket_size=5)
|
||||
def test_insert_past_multi(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
|
||||
obj1 = object()
|
||||
obj2 = object()
|
||||
|
Loading…
Reference in New Issue
Block a user