Mirror of https://mau.dev/maunium/synapse.git (synced 2024-10-01 01:36:05 -04:00)
More types for synapse.util, part 1 (#10888)
The following modules now pass `disallow_untyped_defs`:

* synapse.util.caches.cached_call
* synapse.util.caches.lrucache
* synapse.util.caches.response_cache
* synapse.util.caches.stream_change_cache
* synapse.util.caches.ttlcache
* synapse.util.daemonize
* synapse.util.patch_inline_callbacks
* synapse.util.versionstring

Additional typing in synapse.util.metrics. Didn't get this to pass `no-untyped-defs`; think I'll need to watch #10847.
Parent: 6744273f0b
Commit: f8d0f72b27
changelog.d/10888.misc (new file, 1 addition):

```diff
@@ -0,0 +1 @@
+Improve type hinting in `synapse.util`.
```
mypy.ini (24 additions):

```diff
@@ -102,9 +102,27 @@ disallow_untyped_defs = True
 [mypy-synapse.util.batching_queue]
 disallow_untyped_defs = True
 
+[mypy-synapse.util.caches.cached_call]
+disallow_untyped_defs = True
+
 [mypy-synapse.util.caches.dictionary_cache]
 disallow_untyped_defs = True
 
+[mypy-synapse.util.caches.lrucache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.response_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.stream_change_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.ttl_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.daemonize]
+disallow_untyped_defs = True
+
 [mypy-synapse.util.file_consumer]
 disallow_untyped_defs = True
 
@@ -141,6 +159,9 @@ disallow_untyped_defs = True
 [mypy-synapse.util.msisdn]
 disallow_untyped_defs = True
 
+[mypy-synapse.util.patch_inline_callbacks]
+disallow_untyped_defs = True
+
 [mypy-synapse.util.ratelimitutils]
 disallow_untyped_defs = True
 
@@ -162,6 +183,9 @@ disallow_untyped_defs = True
 [mypy-synapse.util.wheel_timer]
 disallow_untyped_defs = True
 
+[mypy-synapse.util.versionstring]
+disallow_untyped_defs = True
+
 [mypy-tests.handlers.test_user_directory]
 disallow_untyped_defs = True
 
```
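For readers unfamiliar with the flag: `disallow_untyped_defs = True` makes mypy reject any function in the listed modules whose signature is not fully annotated. A minimal sketch of what it enforces (illustrative function names, not from this commit):

```python
def scale(x, factor=2):  # mypy error under the flag: untyped arguments/return
    return x * factor


def scale_typed(x: int, factor: int = 2) -> int:  # accepted: fully annotated
    return x * factor
```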
synapse/util/caches/cached_call.py:

```diff
@@ -85,7 +85,7 @@ class CachedCall(Generic[TV]):
         # result in the deferred, since `awaiting` a deferred destroys its result.
         # (Also, if it's a Failure, GCing the deferred would log a critical error
         # about unhandled Failures)
-        def got_result(r):
+        def got_result(r: Union[TV, Failure]) -> None:
             self._result = r
 
         self._deferred.addBoth(got_result)
```
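The `got_result` annotation reflects how `addBoth` behaves: the same callback fires on success (with the `TV` result) and on error (with a `Failure`), so its argument is `Union[TV, Failure]`. A small runnable sketch, assuming Twisted is installed:

```python
from typing import Union

from twisted.internet import defer
from twisted.python.failure import Failure


def got_result(r: Union[int, Failure]) -> None:
    # addBoth routes both outcomes here: a plain value on success,
    # a Failure wrapping the exception on error.
    if isinstance(r, Failure):
        print("failed:", r.getErrorMessage())
    else:
        print("succeeded:", r)


d: "defer.Deferred[int]" = defer.Deferred()
d.addBoth(got_result)
d.callback(42)  # prints "succeeded: 42"
```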
synapse/util/caches/deferred_cache.py:

```diff
@@ -31,6 +31,7 @@ from prometheus_client import Gauge
 
 from twisted.internet import defer
 from twisted.python import failure
+from twisted.python.failure import Failure
 
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.lrucache import LruCache
@@ -112,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
         self.thread: Optional[threading.Thread] = None
 
     @property
-    def max_entries(self):
+    def max_entries(self) -> int:
         return self.cache.max_size
 
     def check_thread(self) -> None:
@@ -258,7 +259,7 @@ class DeferredCache(Generic[KT, VT]):
 
             return False
 
-        def cb(result) -> None:
+        def cb(result: VT) -> None:
             if compare_and_pop():
                 self.cache.set(key, result, entry.callbacks)
             else:
@@ -270,7 +271,7 @@ class DeferredCache(Generic[KT, VT]):
                 # not have been. Either way, let's double-check now.
                 entry.invalidate()
 
-        def eb(_fail) -> None:
+        def eb(_fail: Failure) -> None:
             compare_and_pop()
             entry.invalidate()
 
@@ -284,11 +285,11 @@ class DeferredCache(Generic[KT, VT]):
 
     def prefill(
         self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
-    ):
+    ) -> None:
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
 
-    def invalidate(self, key):
+    def invalidate(self, key) -> None:
         """Delete a key, or tree of entries
 
         If the cache is backed by a regular dict, then "key" must be of
```
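`cb` and `eb` above are the success and error halves of `addCallbacks`, which is why they can be typed more precisely than an `addBoth` handler: the callback sees the value type, the errback always sees a `Failure`. A minimal sketch of the pattern (illustrative values, not `DeferredCache` itself):

```python
from twisted.internet import defer
from twisted.python.failure import Failure


def cb(result: str) -> None:
    # Success path: receives the cached value type directly.
    print("cached:", result)


def eb(_fail: Failure) -> None:
    # Error path: always receives a Failure.
    print("dropping failed lookup")


d: "defer.Deferred[str]" = defer.Deferred()
d.addCallbacks(cb, eb)
d.callback("hello")  # prints "cached: hello"
```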
synapse/util/caches/lrucache.py:

```diff
@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
 try:
     from pympler.asizeof import Asizer
 
-    def _get_size_of(val: Any, *, recurse=True) -> int:
+    def _get_size_of(val: Any, *, recurse: bool = True) -> int:
         """Get an estimate of the size in bytes of the object.
 
         Args:
@@ -71,7 +71,7 @@ try:
 
 except ImportError:
 
-    def _get_size_of(val: Any, *, recurse=True) -> int:
+    def _get_size_of(val: Any, *, recurse: bool = True) -> int:
         return 0
 
 
@@ -85,15 +85,6 @@ VT = TypeVar("VT")
 # a general type var, distinct from either KT or VT
 T = TypeVar("T")
 
-
-def enumerate_leaves(node, depth):
-    if depth == 0:
-        yield node
-    else:
-        for n in node.values():
-            yield from enumerate_leaves(n, depth - 1)
-
-
 P = TypeVar("P")
 
 
@@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]):
 
     __slots__ = ["last_access_ts_secs"]
 
-    def update_last_access(self, clock: Clock):
+    def update_last_access(self, clock: Clock) -> None:
         self.last_access_ts_secs = int(clock.time())
 
 
@@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
 
 
 @wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
     """Walks the global cache list to find cache entries that haven't been
     accessed in the given number of seconds.
     """
@@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int):
     logger.info("Dropped %d items from caches", i)
 
 
-def setup_expire_lru_cache_entries(hs: "HomeServer"):
+def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
     """Start a background job that expires all cache entries if they have not
     been accessed for the given number of seconds.
     """
@@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"):
     )
 
 
-class _Node:
+class _Node(Generic[KT, VT]):
     __slots__ = [
         "_list_node",
         "_global_list_node",
@@ -197,8 +188,8 @@ class _Node:
     def __init__(
         self,
         root: "ListNode[_Node]",
-        key,
-        value,
+        key: KT,
+        value: VT,
         cache: "weakref.ReferenceType[LruCache]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
@@ -409,7 +400,7 @@ class LruCache(Generic[KT, VT]):
 
         def synchronized(f: FT) -> FT:
             @wraps(f)
-            def inner(*args, **kwargs):
+            def inner(*args: Any, **kwargs: Any) -> Any:
                 with lock:
                     return f(*args, **kwargs)
 
@@ -418,17 +409,19 @@ class LruCache(Generic[KT, VT]):
         cached_cache_len = [0]
         if size_callback is not None:
 
-            def cache_len():
+            def cache_len() -> int:
                 return cached_cache_len[0]
 
         else:
 
-            def cache_len():
+            def cache_len() -> int:
                 return len(cache)
 
         self.len = synchronized(cache_len)
 
-        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
+        def add_node(
+            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
+        ) -> None:
             node = _Node(
                 list_root,
                 key,
@@ -446,7 +439,7 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node):
+        def move_node_to_front(node: _Node) -> None:
             node.move_to_front(real_clock, list_root)
 
         def delete_node(node: _Node) -> int:
@@ -488,7 +481,7 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
-        ):
+        ) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
@@ -502,7 +495,9 @@ class LruCache(Generic[KT, VT]):
                 return default
 
         @synchronized
-        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
+        def cache_set(
+            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+        ) -> None:
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -547,7 +542,7 @@ class LruCache(Generic[KT, VT]):
             ...
 
         @synchronized
-        def cache_pop(key: KT, default: Optional[T] = None):
+        def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -612,25 +607,25 @@ class LruCache(Generic[KT, VT]):
         self.contains = cache_contains
         self.clear = cache_clear
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         result = self.get(key, self.sentinel)
         if result is self.sentinel:
             raise KeyError()
         else:
-            return result
+            return cast(VT, result)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)
 
-    def __delitem__(self, key, value):
+    def __delitem__(self, key: KT, value: VT) -> None:
         result = self.pop(key, self.sentinel)
         if result is self.sentinel:
             raise KeyError()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.len()
 
-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return self.contains(key)
 
     def set_cache_factor(self, factor: float) -> bool:
```
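The main structural change above is `_Node` becoming `Generic[KT, VT]`, so keys and values are type-checked instead of falling back to `Any`. A stripped-down sketch of the pattern (a hypothetical `Node`, not the real `_Node`):

```python
from typing import Generic, TypeVar

KT = TypeVar("KT")
VT = TypeVar("VT")


class Node(Generic[KT, VT]):
    """A hypothetical key/value node, parametrised like _Node."""

    __slots__ = ["key", "value"]

    def __init__(self, key: KT, value: VT) -> None:
        self.key = key
        self.value = value


n: Node[str, int] = Node("hits", 10)
# n2: Node[str, int] = Node(1, "x")  # mypy would flag the swapped types
```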
synapse/util/caches/response_cache.py:

```diff
@@ -104,8 +104,8 @@ class ResponseCache(Generic[KV]):
         return None
 
     def _set(
-        self, context: ResponseCacheContext[KV], deferred: defer.Deferred
-    ) -> defer.Deferred:
+        self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
+    ) -> "defer.Deferred[RV]":
         """Set the entry for the given key to the given deferred.
 
         *deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -126,7 +126,7 @@ class ResponseCache(Generic[KV]):
         key = context.cache_key
         self.pending_result_cache[key] = result
 
-        def on_complete(r):
+        def on_complete(r: RV) -> RV:
             # if this cache has a non-zero timeout, and the callback has not cleared
             # the should_cache bit, we leave it in the cache for now and schedule
             # its removal later.
```
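`"defer.Deferred[RV]"` is quoted because `Deferred` is generic only for the type checker: on older Twisted versions, subscripting it at runtime raises, while a string annotation is never evaluated. A sketch of the idiom:

```python
from typing import TypeVar

from twisted.internet import defer

RV = TypeVar("RV")


def passthrough(deferred: "defer.Deferred[RV]") -> "defer.Deferred[RV]":
    # The quoted annotations stay strings at runtime, so this works even
    # where Deferred[...] cannot be subscripted.
    def on_complete(r: RV) -> RV:
        return r  # returning r keeps the result flowing to later callbacks

    deferred.addCallback(on_complete)
    return deferred
```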
synapse/util/caches/stream_change_cache.py:

```diff
@@ -40,10 +40,10 @@ class StreamChangeCache:
         self,
         name: str,
         current_stream_pos: int,
-        max_size=10000,
+        max_size: int = 10000,
         prefilled_cache: Optional[Mapping[EntityType, int]] = None,
-    ):
-        self._original_max_size = max_size
+    ) -> None:
+        self._original_max_size: int = max_size
         self._max_size = math.floor(max_size)
         self._entity_to_key: Dict[EntityType, int] = {}
 
```
synapse/util/caches/ttlcache.py:

```diff
@@ -159,12 +159,12 @@ class TTLCache(Generic[KT, VT]):
             del self._expiry_list[0]
 
 
-@attr.s(frozen=True, slots=True)
-class _CacheEntry:
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _CacheEntry:  # Should be Generic[KT, VT]. See python-attrs/attrs#313
     """TTLCache entry"""
 
     # expiry_time is the first attribute, so that entries are sorted by expiry.
-    expiry_time = attr.ib(type=float)
-    ttl = attr.ib(type=float)
-    key = attr.ib()
-    value = attr.ib()
+    expiry_time: float
+    ttl: float
+    key: Any  # should be KT
+    value: Any  # should be VT
```
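`auto_attribs=True` lets the class declare fields with plain annotations instead of `attr.ib(type=...)`, which is what gives mypy visibility into the field types. A minimal sketch of the resulting behaviour:

```python
import attr


@attr.s(frozen=True, slots=True, auto_attribs=True)
class CacheEntry:
    # Plain annotations replace attr.ib(type=float); attrs still generates
    # __init__, and mypy now sees real field types.
    expiry_time: float
    ttl: float


e = CacheEntry(expiry_time=100.0, ttl=30.0)
print(e.expiry_time)  # 100.0; the instance stays frozen and slotted
```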
synapse/util/daemonize.py:

```diff
@@ -19,6 +19,8 @@ import logging
 import os
 import signal
 import sys
+from types import FrameType, TracebackType
+from typing import NoReturn, Type
 
 
 def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
@@ -97,7 +99,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
     # (we don't normally expect reactor.run to raise any exceptions, but this will
     # also catch any other uncaught exceptions before we get that far.)
 
-    def excepthook(type_, value, traceback):
+    def excepthook(
+        type_: Type[BaseException], value: BaseException, traceback: TracebackType
+    ) -> None:
         logger.critical("Unhanded exception", exc_info=(type_, value, traceback))
 
     sys.excepthook = excepthook
@@ -119,7 +123,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
         sys.exit(1)
 
     # write a log line on SIGTERM.
-    def sigterm(signum, frame):
+    def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn:
         logger.warning("Caught signal %s. Stopping daemon." % signum)
         sys.exit(0)
 
```
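The excepthook annotations mirror the `(type, value, traceback)` triple that Python passes to `sys.excepthook`. A self-contained sketch of the same shape (note the real hook can also receive `None` for the traceback, so this follows the commit's simplification rather than the full typeshed signature):

```python
import logging
import sys
from types import TracebackType
from typing import Type

logger = logging.getLogger(__name__)


def excepthook(
    type_: Type[BaseException], value: BaseException, traceback: TracebackType
) -> None:
    # Log the uncaught exception with its full traceback.
    logger.critical("Unhandled exception", exc_info=(type_, value, traceback))


sys.excepthook = excepthook
```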
synapse/util/metrics.py:

```diff
@@ -14,9 +14,11 @@
 
 import logging
 from functools import wraps
-from typing import Any, Callable, Optional, TypeVar, cast
+from types import TracebackType
+from typing import Any, Callable, Optional, Type, TypeVar, cast
 
 from prometheus_client import Counter
+from typing_extensions import Protocol
 
 from synapse.logging.context import (
     ContextResourceUsage,
@@ -24,6 +26,7 @@ from synapse.logging.context import (
     current_context,
 )
 from synapse.metrics import InFlightGauge
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -64,6 +67,10 @@ in_flight = InFlightGauge(
 T = TypeVar("T", bound=Callable[..., Any])
 
 
+class HasClock(Protocol):
+    clock: Clock
+
+
 def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
     """
     Used to decorate an async function with a `Measure` context manager.
@@ -86,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
         block_name = func.__name__ if name is None else name
 
         @wraps(func)
-        async def measured_func(self, *args, **kwargs):
+        async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
             with Measure(self.clock, block_name):
                 r = await func(self, *args, **kwargs)
             return r
@@ -104,10 +111,10 @@ class Measure:
         "start",
     ]
 
-    def __init__(self, clock, name: str):
+    def __init__(self, clock: Clock, name: str) -> None:
         """
         Args:
-            clock: A n object with a "time()" method, which returns the current
+            clock: An object with a "time()" method, which returns the current
                 time in seconds.
             name: The name of the metric to report.
         """
@@ -124,7 +131,7 @@ class Measure:
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
         self._logging_context = LoggingContext(str(curr_context), parent_context)
-        self.start: Optional[int] = None
+        self.start: Optional[float] = None
 
     def __enter__(self) -> "Measure":
         if self.start is not None:
@@ -138,7 +145,12 @@ class Measure:
 
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         if self.start is None:
             raise RuntimeError("Measure() block exited without being entered")
 
@@ -168,8 +180,9 @@ class Measure:
         """
         return self._logging_context.get_resource_usage()
 
-    def _update_in_flight(self, metrics):
+    def _update_in_flight(self, metrics) -> None:
         """Gets called when processing in flight metrics"""
+        assert self.start is not None
         duration = self.clock.time() - self.start
 
         metrics.real_time_max = max(metrics.real_time_max, duration)
```
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Callable, List
|
from typing import Any, Callable, Generator, List, TypeVar
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
@ -24,6 +24,9 @@ from twisted.python.failure import Failure
|
|||||||
_already_patched = False
|
_already_patched = False
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def do_patch() -> None:
|
def do_patch() -> None:
|
||||||
"""
|
"""
|
||||||
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
|
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
|
||||||
@ -37,15 +40,19 @@ def do_patch() -> None:
|
|||||||
if _already_patched:
|
if _already_patched:
|
||||||
return
|
return
|
||||||
|
|
||||||
def new_inline_callbacks(f):
|
def new_inline_callbacks(
|
||||||
|
f: Callable[..., Generator["Deferred[object]", object, T]]
|
||||||
|
) -> Callable[..., "Deferred[T]"]:
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
|
||||||
start_context = current_context()
|
start_context = current_context()
|
||||||
changes: List[str] = []
|
changes: List[str] = []
|
||||||
orig = orig_inline_callbacks(_check_yield_points(f, changes))
|
orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
|
||||||
|
_check_yield_points(f, changes)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = orig(*args, **kwargs)
|
res: "Deferred[T]" = orig(*args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
if current_context() != start_context:
|
if current_context() != start_context:
|
||||||
for err in changes:
|
for err in changes:
|
||||||
@ -84,7 +91,7 @@ def do_patch() -> None:
|
|||||||
print(err, file=sys.stderr)
|
print(err, file=sys.stderr)
|
||||||
raise Exception(err)
|
raise Exception(err)
|
||||||
|
|
||||||
def check_ctx(r):
|
def check_ctx(r: T) -> T:
|
||||||
if current_context() != start_context:
|
if current_context() != start_context:
|
||||||
for err in changes:
|
for err in changes:
|
||||||
print(err, file=sys.stderr)
|
print(err, file=sys.stderr)
|
||||||
@ -107,7 +114,10 @@ def do_patch() -> None:
|
|||||||
_already_patched = True
|
_already_patched = True
|
||||||
|
|
||||||
|
|
||||||
def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
|
def _check_yield_points(
|
||||||
|
f: Callable[..., Generator["Deferred[object]", object, T]],
|
||||||
|
changes: List[str],
|
||||||
|
) -> Callable:
|
||||||
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
|
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
|
||||||
checking that after every yield the log contexts are correct.
|
checking that after every yield the log contexts are correct.
|
||||||
|
|
||||||
@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
|
|||||||
from synapse.logging.context import current_context
|
from synapse.logging.context import current_context
|
||||||
|
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def check_yield_points_inner(*args, **kwargs):
|
def check_yield_points_inner(
|
||||||
|
*args: Any, **kwargs: Any
|
||||||
|
) -> Generator["Deferred[object]", object, T]:
|
||||||
gen = f(*args, **kwargs)
|
gen = f(*args, **kwargs)
|
||||||
|
|
||||||
last_yield_line_no = gen.gi_frame.f_lineno
|
last_yield_line_no = gen.gi_frame.f_lineno
|
||||||
|
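The `Generator["Deferred[object]", object, T]` annotation encodes the `inlineCallbacks` contract: the generator yields Deferreds, is sent back their results, and finally returns a `T`. A small sketch, assuming Twisted is installed:

```python
from typing import Generator

from twisted.internet.defer import Deferred, inlineCallbacks, succeed


@inlineCallbacks
def fetch() -> Generator["Deferred[object]", object, int]:
    # Yield a Deferred; inlineCallbacks sends its result back in.
    value = yield succeed(41)
    assert isinstance(value, int)
    return value + 1


d = fetch()  # a Deferred that fires with 42
d.addCallback(print)
```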
synapse/util/versionstring.py:

```diff
@@ -15,14 +15,18 @@
 import logging
 import os
 import subprocess
+from types import ModuleType
+from typing import Dict
 
 logger = logging.getLogger(__name__)
 
+version_cache: Dict[ModuleType, str] = {}
+
 
-def get_version_string(module) -> str:
+def get_version_string(module: ModuleType) -> str:
     """Given a module calculate a git-aware version string for it.
 
-    If called on a module not in a git checkout will return `__verison__`.
+    If called on a module not in a git checkout will return `__version__`.
 
     Args:
         module (module)
@@ -31,11 +35,13 @@ def get_version_string(module) -> str:
         str
     """
 
-    cached_version = getattr(module, "_synapse_version_string_cache", None)
-    if cached_version:
+    cached_version = version_cache.get(module)
+    if cached_version is not None:
         return cached_version
 
-    version_string = module.__version__
+    # We want this to fail loudly with an AttributeError. Type-ignore this so
+    # mypy only considers the happy path.
+    version_string = module.__version__  # type: ignore[attr-defined]
 
     try:
         null = open(os.devnull, "w")
@@ -97,10 +103,15 @@ def get_version_string(module) -> str:
             s for s in (git_branch, git_tag, git_commit, git_dirty) if s
         )
 
-        version_string = "%s (%s)" % (module.__version__, git_version)
+        version_string = "%s (%s)" % (
+            # If the __version__ attribute doesn't exist, we'll have failed
+            # loudly above.
+            module.__version__,  # type: ignore[attr-defined]
+            git_version,
+        )
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)
 
-    module._synapse_version_string_cache = version_string
+    version_cache[module] = version_string
 
     return version_string
```
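Replacing the `module._synapse_version_string_cache` attribute with a typed module-level dict is what lets mypy check the cache: a setattr on an arbitrary module is invisible to the type checker. A condensed sketch of just the caching logic (the git-inspection code is omitted):

```python
from types import ModuleType
from typing import Dict

# Typed replacement for the old module-attribute cache.
version_cache: Dict[ModuleType, str] = {}


def get_version_string(module: ModuleType) -> str:
    cached = version_cache.get(module)
    if cached is not None:
        return cached
    # Deliberately fails loudly if the module has no __version__.
    version_string: str = module.__version__  # type: ignore[attr-defined]
    version_cache[module] = version_string
    return version_string
```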