Add missing type hints to test.util.caches (#14529)

This commit is contained in:
Patrick Cloke 2022-11-22 17:35:54 -05:00 committed by GitHub
parent 7f78b383ca
commit 4ae967cf63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 76 additions and 66 deletions

1
changelog.d/14529.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints.

View File

@ -59,11 +59,6 @@ exclude = (?x)
|tests/server_notices/test_resource_limits_server_notices.py |tests/server_notices/test_resource_limits_server_notices.py
|tests/test_state.py |tests/test_state.py
|tests/test_terms_auth.py |tests/test_terms_auth.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py
|tests/util/caches/test_response_cache.py
|tests/util/caches/test_ttlcache.py
|tests/util/test_async_helpers.py |tests/util/test_async_helpers.py
|tests/util/test_batching_queue.py |tests/util/test_batching_queue.py
|tests/util/test_dict_cache.py |tests/util/test_dict_cache.py
@ -133,6 +128,12 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client] [mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.util.caches.*]
disallow_untyped_defs = True
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False
[mypy-tests.utils] [mypy-tests.utils]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import NoReturn
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -23,14 +24,14 @@ from tests.unittest import TestCase
class CachedCallTestCase(TestCase): class CachedCallTestCase(TestCase):
def test_get(self): def test_get(self) -> None:
""" """
Happy-path test case: makes a couple of calls and makes sure they behave Happy-path test case: makes a couple of calls and makes sure they behave
correctly correctly
""" """
d = Deferred() d: "Deferred[int]" = Deferred()
async def f(): async def f() -> int:
return await d return await d
slow_call = Mock(side_effect=f) slow_call = Mock(side_effect=f)
@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase):
# now fire off a couple of calls # now fire off a couple of calls
completed_results = [] completed_results = []
async def r(): async def r() -> None:
res = await cached_call.get() res = await cached_call.get()
completed_results.append(res) completed_results.append(res)
@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase):
self.assertEqual(r3, 123) self.assertEqual(r3, 123)
slow_call.assert_not_called() slow_call.assert_not_called()
def test_fast_call(self): def test_fast_call(self) -> None:
""" """
Test the behaviour when the underlying function completes immediately Test the behaviour when the underlying function completes immediately
""" """
async def f(): async def f() -> int:
return 12 return 12
fast_call = Mock(side_effect=f) fast_call = Mock(side_effect=f)
@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase):
class RetryOnExceptionCachedCallTestCase(TestCase): class RetryOnExceptionCachedCallTestCase(TestCase):
def test_get(self): def test_get(self) -> None:
# set up the RetryOnExceptionCachedCall around a function which will fail # set up the RetryOnExceptionCachedCall around a function which will fail
# (after a while) # (after a while)
d = Deferred() d: "Deferred[int]" = Deferred()
async def f1(): async def f1() -> NoReturn:
await d await d
raise ValueError("moo") raise ValueError("moo")
@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# now fire off a couple of calls # now fire off a couple of calls
completed_results = [] completed_results = []
async def r(): async def r() -> None:
try: try:
await cached_call.get() await cached_call.get()
except Exception as e1: except Exception as e1:
@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# to the getter # to the getter
d = Deferred() d = Deferred()
async def f2(): async def f2() -> int:
return await d return await d
slow_call.reset_mock() slow_call.reset_mock()

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
from typing import List, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -22,20 +23,20 @@ from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase): class DeferredCacheTestCase(TestCase):
def test_empty(self): def test_empty(self) -> None:
cache = DeferredCache("test") cache: DeferredCache[str, int] = DeferredCache("test")
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
cache.get("foo") cache.get("foo")
def test_hit(self): def test_hit(self) -> None:
cache = DeferredCache("test") cache: DeferredCache[str, int] = DeferredCache("test")
cache.prefill("foo", 123) cache.prefill("foo", 123)
self.assertEqual(self.successResultOf(cache.get("foo")), 123) self.assertEqual(self.successResultOf(cache.get("foo")), 123)
def test_hit_deferred(self): def test_hit_deferred(self) -> None:
cache = DeferredCache("test") cache: DeferredCache[str, int] = DeferredCache("test")
origin_d = defer.Deferred() origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d) set_d = cache.set("k1", origin_d)
# get should return an incomplete deferred # get should return an incomplete deferred
@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase):
self.assertFalse(get_d.called) self.assertFalse(get_d.called)
# add a callback that will make sure that the set_d gets called before the get_d # add a callback that will make sure that the set_d gets called before the get_d
def check1(r): def check1(r: str) -> str:
self.assertTrue(set_d.called) self.assertTrue(set_d.called)
return r return r
@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase):
self.assertEqual(self.successResultOf(set_d), 99) self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99) self.assertEqual(self.successResultOf(get_d), 99)
def test_callbacks(self): def test_callbacks(self) -> None:
"""Invalidation callbacks are called at the right time""" """Invalidation callbacks are called at the right time"""
cache = DeferredCache("test") cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set() callbacks = set()
# start with an entry, with a callback # start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result # now replace that entry with a pending result
origin_d = defer.Deferred() origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request # ... and also make a get request
@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30) cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"}) self.assertEqual(callbacks, {"set", "get"})
def test_set_fail(self): def test_set_fail(self) -> None:
cache = DeferredCache("test") cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set() callbacks = set()
# start with an entry, with a callback # start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result # now replace that entry with a pending result
origin_d = defer.Deferred() origin_d: defer.Deferred = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request # ... and also make a get request
@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30) cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"}) self.assertEqual(callbacks, {"prefill", "get2"})
def test_get_immediate(self): def test_get_immediate(self) -> None:
cache = DeferredCache("test") cache: DeferredCache[str, int] = DeferredCache("test")
d1 = defer.Deferred() d1: "defer.Deferred[int]" = defer.Deferred()
cache.set("key1", d1) cache.set("key1", d1)
# get_immediate should return default # get_immediate should return default
@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase):
v = cache.get_immediate("key1", 1) v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2) self.assertEqual(v, 2)
def test_invalidate(self): def test_invalidate(self) -> None:
cache = DeferredCache("test") cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
cache.prefill(("foo",), 123) cache.prefill(("foo",), 123)
cache.invalidate(("foo",)) cache.invalidate(("foo",))
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
cache.get(("foo",)) cache.get(("foo",))
def test_invalidate_all(self): def test_invalidate_all(self) -> None:
cache = DeferredCache("testcache") cache: DeferredCache[str, str] = DeferredCache("testcache")
callback_record = [False, False] callback_record = [False, False]
def record_callback(idx): def record_callback(idx: int) -> None:
callback_record[idx] = True callback_record[idx] = True
# add a couple of pending entries # add a couple of pending entries
d1 = defer.Deferred() d1: "defer.Deferred[str]" = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0)) cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred() d2: "defer.Deferred[str]" = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1)) cache.set("key2", d2, partial(record_callback, 1))
# lookup should return pending deferreds # lookup should return pending deferreds
@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
cache.get("key1", None) cache.get("key1", None)
def test_eviction(self): def test_eviction(self) -> None:
cache = DeferredCache( cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False "test", max_entries=2, apply_cache_factor_from_config=False
) )
@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(2) cache.get(2)
cache.get(3) cache.get(3)
def test_eviction_lru(self): def test_eviction_lru(self) -> None:
cache = DeferredCache( cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False "test", max_entries=2, apply_cache_factor_from_config=False
) )
@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(1) cache.get(1)
cache.get(3) cache.get(3)
def test_eviction_iterable(self): def test_eviction_iterable(self) -> None:
cache = DeferredCache( cache: DeferredCache[int, List[str]] = DeferredCache(
"test", "test",
max_entries=3, max_entries=3,
apply_cache_factor_from_config=False, apply_cache_factor_from_config=False,

View File

@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, Set, Tuple from typing import Iterable, Set, Tuple, cast
from unittest import mock from unittest import mock
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError, Deferred from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.interfaces import IReactorTime
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.logging.context import ( from synapse.logging.context import (
@ -37,8 +38,8 @@ logger = logging.getLogger(__name__)
def run_on_reactor(): def run_on_reactor():
d = defer.Deferred() d: "Deferred[int]" = defer.Deferred()
reactor.callLater(0, d.callback, 0) cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
callbacks: Set[str] = set() callbacks: Set[str] = set()
# set off an asynchronous request # set off an asynchronous request
obj.result = origin_d = defer.Deferred() origin_d: Deferred = defer.Deferred()
obj.result = origin_d
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called) self.assertFalse(d1.called)
@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
"""Check that logcontexts are set and restored correctly when """Check that logcontexts are set and restored correctly when
using the cache.""" using the cache."""
complete_lookup = defer.Deferred() complete_lookup: Deferred = defer.Deferred()
class Cls: class Cls:
@descriptors.cached() @descriptors.cached()
@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList(cached_method_name="fn", list_name="args1") @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2): async def list_fn(self, args1, arg2):
assert current_context().name == "c1" context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
# we want this to behave like an asynchronous function # we want this to behave like an asynchronous function
await run_on_reactor() await run_on_reactor()
assert current_context().name == "c1" context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
return self.mock(args1, arg2) return self.mock(args1, arg2)
with LoggingContext("c1") as c1: with LoggingContext("c1") as c1:
@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
return self.mock(args1) return self.mock(args1)
obj = Cls() obj = Cls()
deferred_result = Deferred() deferred_result: "Deferred[dict]" = Deferred()
obj.mock.return_value = deferred_result obj.mock.return_value = deferred_result
# start off several concurrent lookups of the same key # start off several concurrent lookups of the same key

View File

@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase):
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock) (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
""" """
def setUp(self): def setUp(self) -> None:
self.reactor, self.clock = get_clock() self.reactor, self.clock = get_clock()
def with_cache(self, name: str, ms: int = 0) -> ResponseCache: def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase):
await self.clock.sleep(1) await self.clock.sleep(1)
return o return o
def test_cache_hit(self): def test_cache_hit(self) -> None:
cache = self.with_cache("keeping_cache", ms=9001) cache = self.with_cache("keeping_cache", ms=9001)
expected_result = "howdy" expected_result = "howdy"
@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase):
"cache should still have the result", "cache should still have the result",
) )
def test_cache_miss(self): def test_cache_miss(self) -> None:
cache = self.with_cache("trashing_cache", ms=0) cache = self.with_cache("trashing_cache", ms=0)
expected_result = "howdy" expected_result = "howdy"
@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase):
) )
self.assertCountEqual([], cache.keys(), "cache should not have the result now") self.assertCountEqual([], cache.keys(), "cache should not have the result now")
def test_cache_expire(self): def test_cache_expire(self) -> None:
cache = self.with_cache("short_cache", ms=1000) cache = self.with_cache("short_cache", ms=1000)
expected_result = "howdy" expected_result = "howdy"
@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase):
self.reactor.pump((2,)) self.reactor.pump((2,))
self.assertCountEqual([], cache.keys(), "cache should not have the result now") self.assertCountEqual([], cache.keys(), "cache should not have the result now")
def test_cache_wait_hit(self): def test_cache_wait_hit(self) -> None:
cache = self.with_cache("neutral_cache") cache = self.with_cache("neutral_cache")
expected_result = "howdy" expected_result = "howdy"
@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase):
self.assertEqual(expected_result, self.successResultOf(wrap_d)) self.assertEqual(expected_result, self.successResultOf(wrap_d))
def test_cache_wait_expire(self): def test_cache_wait_expire(self) -> None:
cache = self.with_cache("medium_cache", ms=3000) cache = self.with_cache("medium_cache", ms=3000)
expected_result = "howdy" expected_result = "howdy"
@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase):
self.assertCountEqual([], cache.keys(), "cache should not have the result now") self.assertCountEqual([], cache.keys(), "cache should not have the result now")
@parameterized.expand([(True,), (False,)]) @parameterized.expand([(True,), (False,)])
def test_cache_context_nocache(self, should_cache: bool): def test_cache_context_nocache(self, should_cache: bool) -> None:
"""If the callback clears the should_cache bit, the result should not be cached""" """If the callback clears the should_cache bit, the result should not be cached"""
cache = self.with_cache("medium_cache", ms=3000) cache = self.with_cache("medium_cache", ms=3000)
@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase):
call_count = 0 call_count = 0
async def non_caching(o: str, cache_context: ResponseCacheContext[int]): async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str:
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
await self.clock.sleep(1) await self.clock.sleep(1)

View File

@ -20,11 +20,11 @@ from tests import unittest
class CacheTestCase(unittest.TestCase): class CacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.mock_timer = Mock(side_effect=lambda: 100.0) self.mock_timer = Mock(side_effect=lambda: 100.0)
self.cache = TTLCache("test_cache", self.mock_timer) self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
def test_get(self): def test_get(self) -> None:
"""simple set/get tests""" """simple set/get tests"""
self.cache.set("one", "1", 10) self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20) self.cache.set("two", "2", 20)
@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase):
self.assertEqual(self.cache._metrics.hits, 4) self.assertEqual(self.cache._metrics.hits, 4)
self.assertEqual(self.cache._metrics.misses, 5) self.assertEqual(self.cache._metrics.misses, 5)
def test_expiry(self): def test_expiry(self) -> None:
self.cache.set("one", "1", 10) self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20) self.cache.set("two", "2", 20)
self.cache.set("three", "3", 30) self.cache.set("three", "3", 30)