mirror of
				https://git.anonymousland.org/anonymousland/synapse.git
				synced 2025-10-31 16:28:57 -04:00 
			
		
		
		
	Add types to synapse.util. (#10601)
This commit is contained in:
		
							parent
							
								
									ceab5a4bfa
								
							
						
					
					
						commit
						524b8ead77
					
				
					 41 changed files with 400 additions and 253 deletions
				
			
		
							
								
								
									
										1
									
								
								changelog.d/10601.misc
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/10601.misc
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1 @@ | |||
| Add type annotations to the synapse.util package. | ||||
							
								
								
									
										75
									
								
								mypy.ini
									
										
									
									
									
								
							
							
						
						
									
										75
									
								
								mypy.ini
									
										
									
									
									
								
							|  | @ -74,17 +74,7 @@ files = | |||
|   synapse/storage/util, | ||||
|   synapse/streams, | ||||
|   synapse/types.py, | ||||
|   synapse/util/async_helpers.py, | ||||
|   synapse/util/caches, | ||||
|   synapse/util/daemonize.py, | ||||
|   synapse/util/hash.py, | ||||
|   synapse/util/iterutils.py, | ||||
|   synapse/util/linked_list.py, | ||||
|   synapse/util/metrics.py, | ||||
|   synapse/util/macaroons.py, | ||||
|   synapse/util/module_loader.py, | ||||
|   synapse/util/msisdn.py, | ||||
|   synapse/util/stringutils.py, | ||||
|   synapse/util, | ||||
|   synapse/visibility.py, | ||||
|   tests/replication, | ||||
|   tests/test_event_auth.py, | ||||
|  | @ -102,6 +92,69 @@ files = | |||
| [mypy-synapse.rest.client.*] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.batching_queue] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.caches.dictionary_cache] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.file_consumer] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.frozenutils] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.hash] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.httpresourcetree] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.iterutils] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.linked_list] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.logcontext] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.logformatter] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.macaroons] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.manhole] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.module_loader] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.msisdn] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.ratelimitutils] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.retryutils] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.rlimit] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.stringutils] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.templates] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.threepids] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.util.wheel_timer] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-pymacaroons.*] | ||||
| ignore_missing_imports = True | ||||
| 
 | ||||
|  |  | |||
|  | @ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory): | |||
|     def buildProtocol(self, addr) -> RedisProtocol: ... | ||||
| 
 | ||||
| class SubscriberFactory(RedisFactory): | ||||
|     def __init__(self): ... | ||||
|     def __init__(self) -> None: ... | ||||
|  |  | |||
|  | @ -46,7 +46,7 @@ class Ratelimiter: | |||
|         #   * How many times an action has occurred since a point in time | ||||
|         #   * The point in time | ||||
|         #   * The rate_hz of this particular entry. This can vary per request | ||||
|         self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() | ||||
|         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() | ||||
| 
 | ||||
|     async def can_do_action( | ||||
|         self, | ||||
|  | @ -56,7 +56,7 @@ class Ratelimiter: | |||
|         burst_count: Optional[int] = None, | ||||
|         update: bool = True, | ||||
|         n_actions: int = 1, | ||||
|         _time_now_s: Optional[int] = None, | ||||
|         _time_now_s: Optional[float] = None, | ||||
|     ) -> Tuple[bool, float]: | ||||
|         """Can the entity (e.g. user or IP address) perform the action? | ||||
| 
 | ||||
|  | @ -160,7 +160,7 @@ class Ratelimiter: | |||
| 
 | ||||
|         return allowed, time_allowed | ||||
| 
 | ||||
|     def _prune_message_counts(self, time_now_s: int): | ||||
|     def _prune_message_counts(self, time_now_s: float): | ||||
|         """Remove message count entries that have not exceeded their defined | ||||
|         rate_hz limit | ||||
| 
 | ||||
|  | @ -188,7 +188,7 @@ class Ratelimiter: | |||
|         burst_count: Optional[int] = None, | ||||
|         update: bool = True, | ||||
|         n_actions: int = 1, | ||||
|         _time_now_s: Optional[int] = None, | ||||
|         _time_now_s: Optional[float] = None, | ||||
|     ): | ||||
|         """Checks if an action can be performed. If not, raises a LimitExceededError | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,6 +14,8 @@ | |||
| 
 | ||||
| from typing import Dict, Optional | ||||
| 
 | ||||
| import attr | ||||
| 
 | ||||
| from ._base import Config | ||||
| 
 | ||||
| 
 | ||||
|  | @ -29,18 +31,13 @@ class RateLimitConfig: | |||
|         self.burst_count = int(config.get("burst_count", defaults["burst_count"])) | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(auto_attribs=True) | ||||
| class FederationRateLimitConfig: | ||||
|     _items_and_default = { | ||||
|         "window_size": 1000, | ||||
|         "sleep_limit": 10, | ||||
|         "sleep_delay": 500, | ||||
|         "reject_limit": 50, | ||||
|         "concurrent": 3, | ||||
|     } | ||||
| 
 | ||||
|     def __init__(self, **kwargs): | ||||
|         for i in self._items_and_default.keys(): | ||||
|             setattr(self, i, kwargs.get(i) or self._items_and_default[i]) | ||||
|     window_size: int = 1000 | ||||
|     sleep_limit: int = 10 | ||||
|     sleep_delay: int = 500 | ||||
|     reject_limit: int = 50 | ||||
|     concurrent: int = 3 | ||||
| 
 | ||||
| 
 | ||||
| class RatelimitConfig(Config): | ||||
|  | @ -69,11 +66,15 @@ class RatelimitConfig(Config): | |||
|         else: | ||||
|             self.rc_federation = FederationRateLimitConfig( | ||||
|                 **{ | ||||
|                     "window_size": config.get("federation_rc_window_size"), | ||||
|                     "sleep_limit": config.get("federation_rc_sleep_limit"), | ||||
|                     "sleep_delay": config.get("federation_rc_sleep_delay"), | ||||
|                     "reject_limit": config.get("federation_rc_reject_limit"), | ||||
|                     "concurrent": config.get("federation_rc_concurrent"), | ||||
|                     k: v | ||||
|                     for k, v in { | ||||
|                         "window_size": config.get("federation_rc_window_size"), | ||||
|                         "sleep_limit": config.get("federation_rc_sleep_limit"), | ||||
|                         "sleep_delay": config.get("federation_rc_sleep_delay"), | ||||
|                         "reject_limit": config.get("federation_rc_reject_limit"), | ||||
|                         "concurrent": config.get("federation_rc_concurrent"), | ||||
|                     }.items() | ||||
|                     if v is not None | ||||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ from prometheus_client import Counter | |||
| from typing_extensions import Literal | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from twisted.internet.interfaces import IDelayedCall | ||||
| 
 | ||||
| import synapse.metrics | ||||
| from synapse.api.presence import UserPresenceState | ||||
|  | @ -284,7 +285,9 @@ class FederationSender(AbstractFederationSender): | |||
|         ) | ||||
| 
 | ||||
|         # wake up destinations that have outstanding PDUs to be caught up | ||||
|         self._catchup_after_startup_timer = self.clock.call_later( | ||||
|         self._catchup_after_startup_timer: Optional[ | ||||
|             IDelayedCall | ||||
|         ] = self.clock.call_later( | ||||
|             CATCH_UP_STARTUP_DELAY_SEC, | ||||
|             run_as_background_process, | ||||
|             "wake_destinations_needing_catchup", | ||||
|  | @ -406,7 +409,7 @@ class FederationSender(AbstractFederationSender): | |||
| 
 | ||||
|                         now = self.clock.time_msec() | ||||
|                         ts = await self.store.get_received_ts(event.event_id) | ||||
| 
 | ||||
|                         assert ts is not None | ||||
|                         synapse.metrics.event_processing_lag_by_event.labels( | ||||
|                             "federation_sender" | ||||
|                         ).observe((now - ts) / 1000) | ||||
|  | @ -435,6 +438,7 @@ class FederationSender(AbstractFederationSender): | |||
|                 if events: | ||||
|                     now = self.clock.time_msec() | ||||
|                     ts = await self.store.get_received_ts(events[-1].event_id) | ||||
|                     assert ts is not None | ||||
| 
 | ||||
|                     synapse.metrics.event_processing_lag.labels( | ||||
|                         "federation_sender" | ||||
|  |  | |||
|  | @ -398,6 +398,7 @@ class AccountValidityHandler: | |||
|         """ | ||||
|         now = self.clock.time_msec() | ||||
|         if expiration_ts is None: | ||||
|             assert self._account_validity_period is not None | ||||
|             expiration_ts = now + self._account_validity_period | ||||
| 
 | ||||
|         await self.store.set_account_validity_for_user( | ||||
|  |  | |||
|  | @ -131,6 +131,8 @@ class ApplicationServicesHandler: | |||
| 
 | ||||
|                         now = self.clock.time_msec() | ||||
|                         ts = await self.store.get_received_ts(event.event_id) | ||||
|                         assert ts is not None | ||||
| 
 | ||||
|                         synapse.metrics.event_processing_lag_by_event.labels( | ||||
|                             "appservice_sender" | ||||
|                         ).observe((now - ts) / 1000) | ||||
|  | @ -166,6 +168,7 @@ class ApplicationServicesHandler: | |||
|                     if events: | ||||
|                         now = self.clock.time_msec() | ||||
|                         ts = await self.store.get_received_ts(events[-1].event_id) | ||||
|                         assert ts is not None | ||||
| 
 | ||||
|                         synapse.metrics.event_processing_lag.labels( | ||||
|                             "appservice_sender" | ||||
|  |  | |||
|  | @ -28,6 +28,7 @@ from bisect import bisect | |||
| from contextlib import contextmanager | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Any, | ||||
|     Callable, | ||||
|     Collection, | ||||
|     Dict, | ||||
|  | @ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler): | |||
|         super().__init__(hs) | ||||
|         self.hs = hs | ||||
|         self.server_name = hs.hostname | ||||
|         self.wheel_timer = WheelTimer() | ||||
|         self.wheel_timer: WheelTimer[str] = WheelTimer() | ||||
|         self.notifier = hs.get_notifier() | ||||
|         self._presence_enabled = hs.config.use_presence | ||||
| 
 | ||||
|  | @ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler): | |||
| 
 | ||||
|         prev_state = await self.current_state_for_user(user_id) | ||||
| 
 | ||||
|         new_fields = {"last_active_ts": self.clock.time_msec()} | ||||
|         new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()} | ||||
|         if prev_state.state == PresenceState.UNAVAILABLE: | ||||
|             new_fields["state"] = PresenceState.ONLINE | ||||
| 
 | ||||
|  |  | |||
|  | @ -73,7 +73,7 @@ class FollowerTypingHandler: | |||
|         self._room_typing: Dict[str, Set[str]] = {} | ||||
| 
 | ||||
|         self._member_last_federation_poke: Dict[RoomMember, int] = {} | ||||
|         self.wheel_timer = WheelTimer(bucket_size=5000) | ||||
|         self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000) | ||||
|         self._latest_room_serial = 0 | ||||
| 
 | ||||
|         self.clock.looping_call(self._handle_timeouts, 5000) | ||||
|  |  | |||
|  | @ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet): | |||
|                 # Artificially delay requests if rate > sleep_limit/window_size | ||||
|                 sleep_limit=1, | ||||
|                 # Amount of artificial delay to apply | ||||
|                 sleep_msec=1000, | ||||
|                 sleep_delay=1000, | ||||
|                 # Error with 429 if more than reject_limit requests are queued | ||||
|                 reject_limit=1, | ||||
|                 # Allow 1 request at a time | ||||
|                 concurrent_requests=1, | ||||
|                 concurrent=1, | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
|  | @ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet): | |||
|         Returns: | ||||
|              dictionary for response from /register | ||||
|         """ | ||||
|         result = {"user_id": user_id, "home_server": self.hs.hostname} | ||||
|         result: JsonDict = { | ||||
|             "user_id": user_id, | ||||
|             "home_server": self.hs.hostname, | ||||
|         } | ||||
|         if not params.get("inhibit_login", False): | ||||
|             device_id = params.get("device_id") | ||||
|             initial_display_name = params.get("initial_device_display_name") | ||||
|  | @ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet): | |||
|             user_id, device_id, initial_display_name, is_guest=True | ||||
|         ) | ||||
| 
 | ||||
|         result = { | ||||
|         result: JsonDict = { | ||||
|             "user_id": user_id, | ||||
|             "device_id": device_id, | ||||
|             "access_token": access_token, | ||||
|  |  | |||
|  | @ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource): | |||
|                 yield hs.config.sso.sso_template_dir | ||||
|             yield hs.config.sso.default_template_dir | ||||
| 
 | ||||
|         self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) | ||||
|         self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) | ||||
| 
 | ||||
|     async def _async_render_GET(self, request: Request) -> None: | ||||
|         try: | ||||
|  |  | |||
|  | @ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource): | |||
|                 yield hs.config.sso.sso_template_dir | ||||
|             yield hs.config.sso.default_template_dir | ||||
| 
 | ||||
|         self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) | ||||
|         self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) | ||||
| 
 | ||||
|     async def _async_render_GET(self, request: Request) -> None: | ||||
|         try: | ||||
|  |  | |||
|  | @ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
|                 delta equal to 10% of the validity period. | ||||
|         """ | ||||
|         now_ms = self._clock.time_msec() | ||||
|         assert self._account_validity_period is not None | ||||
|         expiration_ts = now_ms + self._account_validity_period | ||||
| 
 | ||||
|         if use_delta: | ||||
|  |  | |||
|  | @ -38,6 +38,7 @@ from twisted.internet.interfaces import ( | |||
|     IReactorCore, | ||||
|     IReactorPluggableNameResolver, | ||||
|     IReactorTCP, | ||||
|     IReactorThreads, | ||||
|     IReactorTime, | ||||
| ) | ||||
| 
 | ||||
|  | @ -63,7 +64,12 @@ JsonDict = Dict[str, Any] | |||
| # Note that this seems to require inheriting *directly* from Interface in order | ||||
| # for mypy-zope to realize it is an interface. | ||||
| class ISynapseReactor( | ||||
|     IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface | ||||
|     IReactorTCP, | ||||
|     IReactorPluggableNameResolver, | ||||
|     IReactorTime, | ||||
|     IReactorCore, | ||||
|     IReactorThreads, | ||||
|     Interface, | ||||
| ): | ||||
|     """The interfaces necessary for Synapse to function.""" | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,27 +15,35 @@ | |||
| import json | ||||
| import logging | ||||
| import re | ||||
| from typing import Pattern | ||||
| import typing | ||||
| from typing import Any, Callable, Dict, Generator, Pattern | ||||
| 
 | ||||
| import attr | ||||
| from frozendict import frozendict | ||||
| 
 | ||||
| from twisted.internet import defer, task | ||||
| from twisted.internet.defer import Deferred | ||||
| from twisted.internet.interfaces import IDelayedCall, IReactorTime | ||||
| from twisted.internet.task import LoopingCall | ||||
| from twisted.python.failure import Failure | ||||
| 
 | ||||
| from synapse.logging import context | ||||
| 
 | ||||
| if typing.TYPE_CHECKING: | ||||
|     pass | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| _WILDCARD_RUN = re.compile(r"([\?\*]+)") | ||||
| 
 | ||||
| 
 | ||||
| def _reject_invalid_json(val): | ||||
| def _reject_invalid_json(val: Any) -> None: | ||||
|     """Do not allow Infinity, -Infinity, or NaN values in JSON.""" | ||||
|     raise ValueError("Invalid JSON value: '%s'" % val) | ||||
| 
 | ||||
| 
 | ||||
| def _handle_frozendict(obj): | ||||
| def _handle_frozendict(obj: Any) -> Dict[Any, Any]: | ||||
|     """Helper for json_encoder. Makes frozendicts serializable by returning | ||||
|     the underlying dict | ||||
|     """ | ||||
|  | @ -60,10 +68,10 @@ json_encoder = json.JSONEncoder( | |||
| json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) | ||||
| 
 | ||||
| 
 | ||||
| def unwrapFirstError(failure): | ||||
| def unwrapFirstError(failure: Failure) -> Failure: | ||||
|     # defer.gatherResults and DeferredLists wrap failures. | ||||
|     failure.trap(defer.FirstError) | ||||
|     return failure.value.subFailure | ||||
|     return failure.value.subFailure  # type: ignore[union-attr]  # Issue in Twisted's annotations | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(slots=True) | ||||
|  | @ -75,25 +83,25 @@ class Clock: | |||
|         reactor: The Twisted reactor to use. | ||||
|     """ | ||||
| 
 | ||||
|     _reactor = attr.ib() | ||||
|     _reactor: IReactorTime = attr.ib() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def sleep(self, seconds): | ||||
|         d = defer.Deferred() | ||||
|     @defer.inlineCallbacks  # type: ignore[arg-type]  # Issue in Twisted's type annotations | ||||
|     def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": | ||||
|         d: defer.Deferred[float] = defer.Deferred() | ||||
|         with context.PreserveLoggingContext(): | ||||
|             self._reactor.callLater(seconds, d.callback, seconds) | ||||
|             res = yield d | ||||
|         return res | ||||
| 
 | ||||
|     def time(self): | ||||
|     def time(self) -> float: | ||||
|         """Returns the current system time in seconds since epoch.""" | ||||
|         return self._reactor.seconds() | ||||
| 
 | ||||
|     def time_msec(self): | ||||
|     def time_msec(self) -> int: | ||||
|         """Returns the current system time in milliseconds since epoch.""" | ||||
|         return int(self.time() * 1000) | ||||
| 
 | ||||
|     def looping_call(self, f, msec, *args, **kwargs): | ||||
|     def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall: | ||||
|         """Call a function repeatedly. | ||||
| 
 | ||||
|         Waits `msec` initially before calling `f` for the first time. | ||||
|  | @ -102,8 +110,8 @@ class Clock: | |||
|         other than trivial, you probably want to wrap it in run_as_background_process. | ||||
| 
 | ||||
|         Args: | ||||
|             f(function): The function to call repeatedly. | ||||
|             msec(float): How long to wait between calls in milliseconds. | ||||
|             f: The function to call repeatedly. | ||||
|             msec: How long to wait between calls in milliseconds. | ||||
|             *args: Postional arguments to pass to function. | ||||
|             **kwargs: Key arguments to pass to function. | ||||
|         """ | ||||
|  | @ -113,7 +121,7 @@ class Clock: | |||
|         d.addErrback(log_failure, "Looping call died", consumeErrors=False) | ||||
|         return call | ||||
| 
 | ||||
|     def call_later(self, delay, callback, *args, **kwargs): | ||||
|     def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall: | ||||
|         """Call something later | ||||
| 
 | ||||
|         Note that the function will be called with no logcontext, so if it is anything | ||||
|  | @ -133,7 +141,7 @@ class Clock: | |||
|         with context.PreserveLoggingContext(): | ||||
|             return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) | ||||
| 
 | ||||
|     def cancel_call_later(self, timer, ignore_errs=False): | ||||
|     def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None: | ||||
|         try: | ||||
|             timer.cancel() | ||||
|         except Exception: | ||||
|  |  | |||
|  | @ -37,6 +37,7 @@ import attr | |||
| from typing_extensions import ContextManager | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from twisted.internet.base import ReactorBase | ||||
| from twisted.internet.defer import CancelledError | ||||
| from twisted.internet.interfaces import IReactorTime | ||||
| from twisted.python import failure | ||||
|  | @ -268,6 +269,7 @@ class Linearizer: | |||
|         if not clock: | ||||
|             from twisted.internet import reactor | ||||
| 
 | ||||
|             assert isinstance(reactor, ReactorBase) | ||||
|             clock = Clock(reactor) | ||||
|         self._clock = clock | ||||
|         self.max_count = max_count | ||||
|  | @ -411,7 +413,7 @@ class ReadWriteLock: | |||
|     # writers and readers have been resolved. The new writer replaces the latest | ||||
|     # writer. | ||||
| 
 | ||||
|     def __init__(self): | ||||
|     def __init__(self) -> None: | ||||
|         # Latest readers queued | ||||
|         self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} | ||||
| 
 | ||||
|  | @ -503,7 +505,7 @@ def timeout_deferred( | |||
| 
 | ||||
|     timed_out = [False] | ||||
| 
 | ||||
|     def time_it_out(): | ||||
|     def time_it_out() -> None: | ||||
|         timed_out[0] = True | ||||
| 
 | ||||
|         try: | ||||
|  | @ -550,19 +552,21 @@ def timeout_deferred( | |||
|     return new_d | ||||
| 
 | ||||
| 
 | ||||
| # This class can't be generic because it uses slots with attrs. | ||||
| # See: https://github.com/python-attrs/attrs/issues/313 | ||||
| @attr.s(slots=True, frozen=True) | ||||
| class DoneAwaitable: | ||||
| class DoneAwaitable:  # should be: Generic[R] | ||||
|     """Simple awaitable that returns the provided value.""" | ||||
| 
 | ||||
|     value = attr.ib() | ||||
|     value = attr.ib(type=Any)  # should be: R | ||||
| 
 | ||||
|     def __await__(self): | ||||
|         return self | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|     def __iter__(self) -> "DoneAwaitable": | ||||
|         return self | ||||
| 
 | ||||
|     def __next__(self): | ||||
|     def __next__(self) -> None: | ||||
|         raise StopIteration(self.value) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]): | |||
| 
 | ||||
|         # First we create a defer and add it and the value to the list of | ||||
|         # pending items. | ||||
|         d = defer.Deferred() | ||||
|         d: defer.Deferred[R] = defer.Deferred() | ||||
|         self._next_values.setdefault(key, []).append((value, d)) | ||||
| 
 | ||||
|         # If we're not currently processing the key fire off a background | ||||
|  |  | |||
|  | @ -64,32 +64,32 @@ class CacheMetric: | |||
|     evicted_size = attr.ib(default=0) | ||||
|     memory_usage = attr.ib(default=None) | ||||
| 
 | ||||
|     def inc_hits(self): | ||||
|     def inc_hits(self) -> None: | ||||
|         self.hits += 1 | ||||
| 
 | ||||
|     def inc_misses(self): | ||||
|     def inc_misses(self) -> None: | ||||
|         self.misses += 1 | ||||
| 
 | ||||
|     def inc_evictions(self, size=1): | ||||
|     def inc_evictions(self, size: int = 1) -> None: | ||||
|         self.evicted_size += size | ||||
| 
 | ||||
|     def inc_memory_usage(self, memory: int): | ||||
|     def inc_memory_usage(self, memory: int) -> None: | ||||
|         if self.memory_usage is None: | ||||
|             self.memory_usage = 0 | ||||
| 
 | ||||
|         self.memory_usage += memory | ||||
| 
 | ||||
|     def dec_memory_usage(self, memory: int): | ||||
|     def dec_memory_usage(self, memory: int) -> None: | ||||
|         self.memory_usage -= memory | ||||
| 
 | ||||
|     def clear_memory_usage(self): | ||||
|     def clear_memory_usage(self) -> None: | ||||
|         if self.memory_usage is not None: | ||||
|             self.memory_usage = 0 | ||||
| 
 | ||||
|     def describe(self): | ||||
|         return [] | ||||
| 
 | ||||
|     def collect(self): | ||||
|     def collect(self) -> None: | ||||
|         try: | ||||
|             if self._cache_type == "response_cache": | ||||
|                 response_cache_size.labels(self._cache_name).set(len(self._cache)) | ||||
|  |  | |||
|  | @ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]): | |||
|             TreeCache, "MutableMapping[KT, CacheEntry]" | ||||
|         ] = cache_type() | ||||
| 
 | ||||
|         def metrics_cb(): | ||||
|         def metrics_cb() -> None: | ||||
|             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) | ||||
| 
 | ||||
|         # cache is used for completed results and maps to the result itself, rather than | ||||
|  | @ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]): | |||
|     def max_entries(self): | ||||
|         return self.cache.max_size | ||||
| 
 | ||||
|     def check_thread(self): | ||||
|     def check_thread(self) -> None: | ||||
|         expected_thread = self.thread | ||||
|         if expected_thread is None: | ||||
|             self.thread = threading.current_thread() | ||||
|  | @ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]): | |||
| 
 | ||||
|         self._pending_deferred_cache[key] = entry | ||||
| 
 | ||||
|         def compare_and_pop(): | ||||
|         def compare_and_pop() -> bool: | ||||
|             """Check if our entry is still the one in _pending_deferred_cache, and | ||||
|             if so, pop it. | ||||
| 
 | ||||
|  | @ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]): | |||
| 
 | ||||
|             return False | ||||
| 
 | ||||
|         def cb(result): | ||||
|         def cb(result) -> None: | ||||
|             if compare_and_pop(): | ||||
|                 self.cache.set(key, result, entry.callbacks) | ||||
|             else: | ||||
|  | @ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]): | |||
|                 # not have been. Either way, let's double-check now. | ||||
|                 entry.invalidate() | ||||
| 
 | ||||
|         def eb(_fail): | ||||
|         def eb(_fail) -> None: | ||||
|             compare_and_pop() | ||||
|             entry.invalidate() | ||||
| 
 | ||||
|  | @ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]): | |||
|             for entry in iterate_tree_cache_entry(entry): | ||||
|                 entry.invalidate() | ||||
| 
 | ||||
|     def invalidate_all(self): | ||||
|     def invalidate_all(self) -> None: | ||||
|         self.check_thread() | ||||
|         self.cache.clear() | ||||
|         for entry in self._pending_deferred_cache.values(): | ||||
|  | @ -332,7 +332,7 @@ class CacheEntry: | |||
|         self.callbacks = set(callbacks) | ||||
|         self.invalidated = False | ||||
| 
 | ||||
|     def invalidate(self): | ||||
|     def invalidate(self) -> None: | ||||
|         if not self.invalidated: | ||||
|             self.invalidated = True | ||||
|             for callback in self.callbacks: | ||||
|  |  | |||
|  | @ -27,10 +27,14 @@ logger = logging.getLogger(__name__) | |||
| KT = TypeVar("KT") | ||||
| # The type of the dictionary keys. | ||||
| DKT = TypeVar("DKT") | ||||
| # The type of the dictionary values. | ||||
| DV = TypeVar("DV") | ||||
| 
 | ||||
| 
 | ||||
| # This class can't be generic because it uses slots with attrs. | ||||
| # See: https://github.com/python-attrs/attrs/issues/313 | ||||
| @attr.s(slots=True) | ||||
| class DictionaryEntry: | ||||
| class DictionaryEntry:  # should be: Generic[DKT, DV]. | ||||
|     """Returned when getting an entry from the cache | ||||
| 
 | ||||
|     Attributes: | ||||
|  | @ -43,10 +47,10 @@ class DictionaryEntry: | |||
|     """ | ||||
| 
 | ||||
|     full = attr.ib(type=bool) | ||||
|     known_absent = attr.ib() | ||||
|     value = attr.ib() | ||||
|     known_absent = attr.ib(type=Set[Any])  # should be: Set[DKT] | ||||
|     value = attr.ib(type=Dict[Any, Any])  # should be: Dict[DKT, DV] | ||||
| 
 | ||||
|     def __len__(self): | ||||
|     def __len__(self) -> int: | ||||
|         return len(self.value) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -56,7 +60,7 @@ class _Sentinel(enum.Enum): | |||
|     sentinel = object() | ||||
| 
 | ||||
| 
 | ||||
| class DictionaryCache(Generic[KT, DKT]): | ||||
| class DictionaryCache(Generic[KT, DKT, DV]): | ||||
|     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. | ||||
|     fetching a subset of dictionary keys for a particular key. | ||||
|     """ | ||||
|  | @ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]): | |||
| 
 | ||||
|         Args: | ||||
|             key | ||||
|             dict_key: If given a set of keys then return only those keys | ||||
|             dict_keys: If given a set of keys then return only those keys | ||||
|                 that exist in the cache. | ||||
| 
 | ||||
|         Returns: | ||||
|  | @ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]): | |||
|         self, | ||||
|         sequence: int, | ||||
|         key: KT, | ||||
|         value: Dict[DKT, Any], | ||||
|         value: Dict[DKT, DV], | ||||
|         fetched_keys: Optional[Set[DKT]] = None, | ||||
|     ) -> None: | ||||
|         """Updates the entry in the cache | ||||
|  | @ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]): | |||
|                 self._update_or_insert(key, value, fetched_keys) | ||||
| 
 | ||||
|     def _update_or_insert( | ||||
|         self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] | ||||
|         self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT] | ||||
|     ) -> None: | ||||
|         # We pop and reinsert as we need to tell the cache the size may have | ||||
|         # changed | ||||
| 
 | ||||
|         entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) | ||||
|         entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {})) | ||||
|         entry.value.update(value) | ||||
|         entry.known_absent.update(known_absent) | ||||
|         self.cache[key] = entry | ||||
| 
 | ||||
|     def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: | ||||
|     def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None: | ||||
|         self.cache[key] = DictionaryEntry(True, known_absent, value) | ||||
|  |  | |||
|  | @ -35,6 +35,7 @@ from typing import ( | |||
| from typing_extensions import Literal | ||||
| 
 | ||||
| from twisted.internet import reactor | ||||
| from twisted.internet.interfaces import IReactorTime | ||||
| 
 | ||||
| from synapse.config import cache as cache_config | ||||
| from synapse.metrics.background_process_metrics import wrap_as_background_process | ||||
|  | @ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]): | |||
|         # Default `clock` to something sensible. Note that we rename it to | ||||
|         # `real_clock` so that mypy doesn't think its still `Optional`. | ||||
|         if clock is None: | ||||
|             real_clock = Clock(reactor) | ||||
|             real_clock = Clock(cast(IReactorTime, reactor)) | ||||
|         else: | ||||
|             real_clock = clock | ||||
| 
 | ||||
|  | @ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]): | |||
| 
 | ||||
|         lock = threading.Lock() | ||||
| 
 | ||||
|         def evict(): | ||||
|         def evict() -> None: | ||||
|             while cache_len() > self.max_size: | ||||
|                 # Get the last node in the list (i.e. the oldest node). | ||||
|                 todelete = list_root.prev_node | ||||
|  |  | |||
|  | @ -195,7 +195,7 @@ class StreamChangeCache: | |||
|             for entity in r: | ||||
|                 del self._entity_to_key[entity] | ||||
| 
 | ||||
|     def _evict(self): | ||||
|     def _evict(self) -> None: | ||||
|         while len(self._cache) > self._max_size: | ||||
|             k, r = self._cache.popitem(0) | ||||
|             self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) | ||||
|  |  | |||
|  | @ -35,17 +35,17 @@ class TreeCache: | |||
|         root = {key_1: {key_2: _value}} | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.size = 0 | ||||
|     def __init__(self) -> None: | ||||
|         self.size: int = 0 | ||||
|         self.root = TreeCacheNode() | ||||
| 
 | ||||
|     def __setitem__(self, key, value): | ||||
|         return self.set(key, value) | ||||
|     def __setitem__(self, key, value) -> None: | ||||
|         self.set(key, value) | ||||
| 
 | ||||
|     def __contains__(self, key): | ||||
|     def __contains__(self, key) -> bool: | ||||
|         return self.get(key, SENTINEL) is not SENTINEL | ||||
| 
 | ||||
|     def set(self, key, value): | ||||
|     def set(self, key, value) -> None: | ||||
|         if isinstance(value, TreeCacheNode): | ||||
|             # this would mean we couldn't tell where our tree ended and the value | ||||
|             # started. | ||||
|  | @ -73,7 +73,7 @@ class TreeCache: | |||
|                 return default | ||||
|         return node.get(key[-1], default) | ||||
| 
 | ||||
|     def clear(self): | ||||
|     def clear(self) -> None: | ||||
|         self.size = 0 | ||||
|         self.root = TreeCacheNode() | ||||
| 
 | ||||
|  | @ -128,7 +128,7 @@ class TreeCache: | |||
|     def values(self): | ||||
|         return iterate_tree_cache_entry(self.root) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|     def __len__(self) -> int: | ||||
|         return self.size | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - | |||
|     signal.signal(signal.SIGTERM, sigterm) | ||||
| 
 | ||||
|     # Cleanup pid file at exit. | ||||
|     def exit(): | ||||
|     def exit() -> None: | ||||
|         logger.warning("Stopping daemon.") | ||||
|         os.remove(pid_file) | ||||
|         sys.exit(0) | ||||
|  |  | |||
|  | @ -12,6 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import logging | ||||
| from typing import Any, Callable, Dict, List | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
|  | @ -37,11 +38,11 @@ class Distributor: | |||
|       model will do for today. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.signals = {} | ||||
|         self.pre_registration = {} | ||||
|     def __init__(self) -> None: | ||||
|         self.signals: Dict[str, Signal] = {} | ||||
|         self.pre_registration: Dict[str, List[Callable]] = {} | ||||
| 
 | ||||
|     def declare(self, name): | ||||
|     def declare(self, name: str) -> None: | ||||
|         if name in self.signals: | ||||
|             raise KeyError("%r already has a signal named %s" % (self, name)) | ||||
| 
 | ||||
|  | @ -52,7 +53,7 @@ class Distributor: | |||
|             for observer in self.pre_registration[name]: | ||||
|                 signal.observe(observer) | ||||
| 
 | ||||
|     def observe(self, name, observer): | ||||
|     def observe(self, name: str, observer: Callable) -> None: | ||||
|         if name in self.signals: | ||||
|             self.signals[name].observe(observer) | ||||
|         else: | ||||
|  | @ -62,7 +63,7 @@ class Distributor: | |||
|                 self.pre_registration[name] = [] | ||||
|             self.pre_registration[name].append(observer) | ||||
| 
 | ||||
|     def fire(self, name, *args, **kwargs): | ||||
|     def fire(self, name: str, *args, **kwargs) -> None: | ||||
|         """Dispatches the given signal to the registered observers. | ||||
| 
 | ||||
|         Runs the observers as a background process. Does not return a deferred. | ||||
|  | @ -83,18 +84,18 @@ class Signal: | |||
|     method into all of the observers. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|         self.observers = [] | ||||
|     def __init__(self, name: str): | ||||
|         self.name: str = name | ||||
|         self.observers: List[Callable] = [] | ||||
| 
 | ||||
|     def observe(self, observer): | ||||
|     def observe(self, observer: Callable) -> None: | ||||
|         """Adds a new callable to the observer list which will be invoked by | ||||
|         the 'fire' method. | ||||
| 
 | ||||
|         Each observer callable may return a Deferred.""" | ||||
|         self.observers.append(observer) | ||||
| 
 | ||||
|     def fire(self, *args, **kwargs): | ||||
|     def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": | ||||
|         """Invokes every callable in the observer list, passing in the args and | ||||
|         kwargs. Exceptions thrown by observers are logged but ignored. It is | ||||
|         not an error to fire a signal with no observers. | ||||
|  |  | |||
|  | @ -13,10 +13,14 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| import queue | ||||
| from typing import BinaryIO, Optional, Union, cast | ||||
| 
 | ||||
| from twisted.internet import threads | ||||
| from twisted.internet.defer import Deferred | ||||
| from twisted.internet.interfaces import IPullProducer, IPushProducer | ||||
| 
 | ||||
| from synapse.logging.context import make_deferred_yieldable, run_in_background | ||||
| from synapse.types import ISynapseReactor | ||||
| 
 | ||||
| 
 | ||||
| class BackgroundFileConsumer: | ||||
|  | @ -24,9 +28,9 @@ class BackgroundFileConsumer: | |||
|     and pull producers | ||||
| 
 | ||||
|     Args: | ||||
|         file_obj (file): The file like object to write to. Closed when | ||||
|         file_obj: The file like object to write to. Closed when | ||||
|             finished. | ||||
|         reactor (twisted.internet.reactor): the Twisted reactor to use | ||||
|         reactor: the Twisted reactor to use | ||||
|     """ | ||||
| 
 | ||||
|     # For PushProducers pause if we have this many unwritten slices | ||||
|  | @ -34,13 +38,13 @@ class BackgroundFileConsumer: | |||
|     # And resume once the size of the queue is less than this | ||||
|     _RESUME_ON_QUEUE_SIZE = 2 | ||||
| 
 | ||||
|     def __init__(self, file_obj, reactor): | ||||
|         self._file_obj = file_obj | ||||
|     def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None: | ||||
|         self._file_obj: BinaryIO = file_obj | ||||
| 
 | ||||
|         self._reactor = reactor | ||||
|         self._reactor: ISynapseReactor = reactor | ||||
| 
 | ||||
|         # Producer we're registered with | ||||
|         self._producer = None | ||||
|         self._producer: Optional[Union[IPushProducer, IPullProducer]] = None | ||||
| 
 | ||||
|         # True if PushProducer, false if PullProducer | ||||
|         self.streaming = False | ||||
|  | @ -51,20 +55,22 @@ class BackgroundFileConsumer: | |||
| 
 | ||||
|         # Queue of slices of bytes to be written. When producer calls | ||||
|         # unregister a final None is sent. | ||||
|         self._bytes_queue = queue.Queue() | ||||
|         self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() | ||||
| 
 | ||||
|         # Deferred that is resolved when finished writing | ||||
|         self._finished_deferred = None | ||||
|         self._finished_deferred: Optional[Deferred[None]] = None | ||||
| 
 | ||||
|         # If the _writer thread throws an exception it gets stored here. | ||||
|         self._write_exception = None | ||||
|         self._write_exception: Optional[Exception] = None | ||||
| 
 | ||||
|     def registerProducer(self, producer, streaming): | ||||
|     def registerProducer( | ||||
|         self, producer: Union[IPushProducer, IPullProducer], streaming: bool | ||||
|     ) -> None: | ||||
|         """Part of IConsumer interface | ||||
| 
 | ||||
|         Args: | ||||
|             producer (IProducer) | ||||
|             streaming (bool): True if push based producer, False if pull | ||||
|             producer | ||||
|             streaming: True if push based producer, False if pull | ||||
|                 based. | ||||
|         """ | ||||
|         if self._producer: | ||||
|  | @ -81,29 +87,33 @@ class BackgroundFileConsumer: | |||
|         if not streaming: | ||||
|             self._producer.resumeProducing() | ||||
| 
 | ||||
|     def unregisterProducer(self): | ||||
|     def unregisterProducer(self) -> None: | ||||
|         """Part of IProducer interface""" | ||||
|         self._producer = None | ||||
|         assert self._finished_deferred is not None | ||||
|         if not self._finished_deferred.called: | ||||
|             self._bytes_queue.put_nowait(None) | ||||
| 
 | ||||
|     def write(self, bytes): | ||||
|     def write(self, write_bytes: bytes) -> None: | ||||
|         """Part of IProducer interface""" | ||||
|         if self._write_exception: | ||||
|             raise self._write_exception | ||||
| 
 | ||||
|         assert self._finished_deferred is not None | ||||
|         if self._finished_deferred.called: | ||||
|             raise Exception("consumer has closed") | ||||
| 
 | ||||
|         self._bytes_queue.put_nowait(bytes) | ||||
|         self._bytes_queue.put_nowait(write_bytes) | ||||
| 
 | ||||
|         # If this is a PushProducer and the queue is getting behind | ||||
|         # then we pause the producer. | ||||
|         if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: | ||||
|             self._paused_producer = True | ||||
|             self._producer.pauseProducing() | ||||
|             assert self._producer is not None | ||||
|             # cast safe because `streaming` means this is an IPushProducer | ||||
|             cast(IPushProducer, self._producer).pauseProducing() | ||||
| 
 | ||||
|     def _writer(self): | ||||
|     def _writer(self) -> None: | ||||
|         """This is run in a background thread to write to the file.""" | ||||
|         try: | ||||
|             while self._producer or not self._bytes_queue.empty(): | ||||
|  | @ -130,11 +140,11 @@ class BackgroundFileConsumer: | |||
|         finally: | ||||
|             self._file_obj.close() | ||||
| 
 | ||||
|     def wait(self): | ||||
|     def wait(self) -> "Deferred[None]": | ||||
|         """Returns a deferred that resolves when finished writing to file""" | ||||
|         return make_deferred_yieldable(self._finished_deferred) | ||||
| 
 | ||||
|     def _resume_paused_producer(self): | ||||
|     def _resume_paused_producer(self) -> None: | ||||
|         """Gets called if we should resume producing after being paused""" | ||||
|         if self._paused_producer and self._producer: | ||||
|             self._paused_producer = False | ||||
|  |  | |||
|  | @ -11,11 +11,12 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from typing import Any | ||||
| 
 | ||||
| from frozendict import frozendict | ||||
| 
 | ||||
| 
 | ||||
| def freeze(o): | ||||
| def freeze(o: Any) -> Any: | ||||
|     if isinstance(o, dict): | ||||
|         return frozendict({k: freeze(v) for k, v in o.items()}) | ||||
| 
 | ||||
|  | @ -33,7 +34,7 @@ def freeze(o): | |||
|     return o | ||||
| 
 | ||||
| 
 | ||||
| def unfreeze(o): | ||||
| def unfreeze(o: Any) -> Any: | ||||
|     if isinstance(o, (dict, frozendict)): | ||||
|         return {k: unfreeze(v) for k, v in o.items()} | ||||
| 
 | ||||
|  |  | |||
|  | @ -13,42 +13,43 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| from typing import Dict | ||||
| 
 | ||||
| from twisted.web.resource import NoResource | ||||
| from twisted.web.resource import NoResource, Resource | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def create_resource_tree(desired_tree, root_resource): | ||||
| def create_resource_tree( | ||||
|     desired_tree: Dict[str, Resource], root_resource: Resource | ||||
| ) -> Resource: | ||||
|     """Create the resource tree for this homeserver. | ||||
| 
 | ||||
|     This in unduly complicated because Twisted does not support putting | ||||
|     child resources more than 1 level deep at a time. | ||||
| 
 | ||||
|     Args: | ||||
|         web_client (bool): True to enable the web client. | ||||
|         root_resource (twisted.web.resource.Resource): The root | ||||
|             resource to add the tree to. | ||||
|         desired_tree: Dict from desired paths to desired resources. | ||||
|         root_resource: The root resource to add the tree to. | ||||
|     Returns: | ||||
|         twisted.web.resource.Resource: the ``root_resource`` with a tree of | ||||
|         child resources added to it. | ||||
|         The ``root_resource`` with a tree of child resources added to it. | ||||
|     """ | ||||
| 
 | ||||
|     # ideally we'd just use getChild and putChild but getChild doesn't work | ||||
|     # unless you give it a Request object IN ADDITION to the name :/ So | ||||
|     # instead, we'll store a copy of this mapping so we can actually add | ||||
|     # extra resources to existing nodes. See self._resource_id for the key. | ||||
|     resource_mappings = {} | ||||
|     for full_path, res in desired_tree.items(): | ||||
|     resource_mappings: Dict[str, Resource] = {} | ||||
|     for full_path_str, res in desired_tree.items(): | ||||
|         # twisted requires all resources to be bytes | ||||
|         full_path = full_path.encode("utf-8") | ||||
|         full_path = full_path_str.encode("utf-8") | ||||
| 
 | ||||
|         logger.info("Attaching %s to path %s", res, full_path) | ||||
|         last_resource = root_resource | ||||
|         for path_seg in full_path.split(b"/")[1:-1]: | ||||
|             if path_seg not in last_resource.listNames(): | ||||
|                 # resource doesn't exist, so make a "dummy resource" | ||||
|                 child_resource = NoResource() | ||||
|                 child_resource: Resource = NoResource() | ||||
|                 last_resource.putChild(path_seg, child_resource) | ||||
|                 res_id = _resource_id(last_resource, path_seg) | ||||
|                 resource_mappings[res_id] = child_resource | ||||
|  | @ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource): | |||
|     return root_resource | ||||
| 
 | ||||
| 
 | ||||
| def _resource_id(resource, path_seg): | ||||
| def _resource_id(resource: Resource, path_seg: bytes) -> str: | ||||
|     """Construct an arbitrary resource ID so you can retrieve the mapping | ||||
|     later. | ||||
| 
 | ||||
|  | @ -96,4 +97,4 @@ def _resource_id(resource, path_seg): | |||
|     Returns: | ||||
|         str: A unique string which can be a key to the child Resource. | ||||
|     """ | ||||
|     return "%s-%s" % (resource, path_seg) | ||||
|     return "%s-%r" % (resource, path_seg) | ||||
|  |  | |||
|  | @ -74,7 +74,7 @@ class ListNode(Generic[P]): | |||
|             new_node._refs_insert_after(node) | ||||
|         return new_node | ||||
| 
 | ||||
|     def remove_from_list(self): | ||||
|     def remove_from_list(self) -> None: | ||||
|         """Remove this node from the list.""" | ||||
|         with self._LOCK: | ||||
|             self._refs_remove_node_from_list() | ||||
|  | @ -84,7 +84,7 @@ class ListNode(Generic[P]): | |||
|         # immediately rather than at the next GC. | ||||
|         self.cache_entry = None | ||||
| 
 | ||||
|     def move_after(self, node: "ListNode"): | ||||
|     def move_after(self, node: "ListNode") -> None: | ||||
|         """Move this node from its current location in the list to after the | ||||
|         given node. | ||||
|         """ | ||||
|  | @ -103,7 +103,7 @@ class ListNode(Generic[P]): | |||
|             # Insert self back into the list, after target node | ||||
|             self._refs_insert_after(node) | ||||
| 
 | ||||
|     def _refs_remove_node_from_list(self): | ||||
|     def _refs_remove_node_from_list(self) -> None: | ||||
|         """Internal method to *just* remove the node from the list, without | ||||
|         e.g. clearing out the cache entry. | ||||
|         """ | ||||
|  | @ -122,7 +122,7 @@ class ListNode(Generic[P]): | |||
|         self.prev_node = None | ||||
|         self.next_node = None | ||||
| 
 | ||||
|     def _refs_insert_after(self, node: "ListNode"): | ||||
|     def _refs_insert_after(self, node: "ListNode") -> None: | ||||
|         """Internal method to insert the node after the given node.""" | ||||
| 
 | ||||
|         # This method should only be called when we're not already in the list. | ||||
|  |  | |||
|  | @ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N | |||
|             should be considered expired. Normally the current time. | ||||
|     """ | ||||
| 
 | ||||
|     def verify_expiry_caveat(caveat: str): | ||||
|     def verify_expiry_caveat(caveat: str) -> bool: | ||||
|         time_msec = get_time_ms() | ||||
|         prefix = "time < " | ||||
|         if not caveat.startswith(prefix): | ||||
|  |  | |||
|  | @ -15,6 +15,7 @@ | |||
| import inspect | ||||
| import sys | ||||
| import traceback | ||||
| from typing import Any, Dict, Optional | ||||
| 
 | ||||
| from twisted.conch import manhole_ssh | ||||
| from twisted.conch.insults import insults | ||||
|  | @ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter | |||
| from twisted.conch.ssh.keys import Key | ||||
| from twisted.cred import checkers, portal | ||||
| from twisted.internet import defer | ||||
| from twisted.internet.protocol import Factory | ||||
| 
 | ||||
| from synapse.config.server import ManholeConfig | ||||
| 
 | ||||
| PUBLIC_KEY = ( | ||||
|     "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" | ||||
|  | @ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= | |||
| -----END RSA PRIVATE KEY-----""" | ||||
| 
 | ||||
| 
 | ||||
| def manhole(settings, globals): | ||||
| def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: | ||||
|     """Starts a ssh listener with password authentication using | ||||
|     the given username and password. Clients connecting to the ssh | ||||
|     listener will find themselves in a colored python shell with | ||||
|     the supplied globals. | ||||
| 
 | ||||
|     Args: | ||||
|         username(str): The username ssh clients should auth with. | ||||
|         password(str): The password ssh clients should auth with. | ||||
|         globals(dict): The variables to expose in the shell. | ||||
|         username: The username ssh clients should auth with. | ||||
|         password: The password ssh clients should auth with. | ||||
|         globals: The variables to expose in the shell. | ||||
| 
 | ||||
|     Returns: | ||||
|         twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` | ||||
|         A factory to pass to ``listenTCP`` | ||||
|     """ | ||||
|     username = settings.username | ||||
|     password = settings.password | ||||
|     password = settings.password.encode("ascii") | ||||
|     priv_key = settings.priv_key | ||||
|     if priv_key is None: | ||||
|         priv_key = Key.fromString(PRIVATE_KEY) | ||||
|  | @ -84,19 +88,22 @@ def manhole(settings, globals): | |||
|     if pub_key is None: | ||||
|         pub_key = Key.fromString(PUBLIC_KEY) | ||||
| 
 | ||||
|     if not isinstance(password, bytes): | ||||
|         password = password.encode("ascii") | ||||
| 
 | ||||
|     checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) | ||||
| 
 | ||||
|     rlm = manhole_ssh.TerminalRealm() | ||||
|     rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( | ||||
|     # mypy ignored here because: | ||||
|     # - can't deduce types of lambdas | ||||
|     # - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol] | ||||
|     rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(  # type: ignore[misc,assignment] | ||||
|         SynapseManhole, dict(globals, __name__="__console__") | ||||
|     ) | ||||
| 
 | ||||
|     factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) | ||||
|     factory.privateKeys[b"ssh-rsa"] = priv_key | ||||
|     factory.publicKeys[b"ssh-rsa"] = pub_key | ||||
| 
 | ||||
|     # conch has the wrong type on these dicts (says bytes to bytes, | ||||
|     # should be bytes to Keys judging by how it's used). | ||||
|     factory.privateKeys[b"ssh-rsa"] = priv_key  # type: ignore[assignment] | ||||
|     factory.publicKeys[b"ssh-rsa"] = pub_key  # type: ignore[assignment] | ||||
| 
 | ||||
|     return factory | ||||
| 
 | ||||
|  | @ -104,7 +111,7 @@ def manhole(settings, globals): | |||
| class SynapseManhole(ColoredManhole): | ||||
|     """Overrides connectionMade to create our own ManholeInterpreter""" | ||||
| 
 | ||||
|     def connectionMade(self): | ||||
|     def connectionMade(self) -> None: | ||||
|         super().connectionMade() | ||||
| 
 | ||||
|         # replace the manhole interpreter with our own impl | ||||
|  | @ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole): | |||
| 
 | ||||
| 
 | ||||
| class SynapseManholeInterpreter(ManholeInterpreter): | ||||
|     def showsyntaxerror(self, filename=None): | ||||
|     def showsyntaxerror(self, filename: Optional[str] = None) -> None: | ||||
|         """Display the syntax error that just occurred. | ||||
| 
 | ||||
|         Overrides the base implementation, ignoring sys.excepthook. We always want | ||||
|         any syntax errors to be sent to the terminal, rather than sentry. | ||||
|         """ | ||||
|         type, value, tb = sys.exc_info() | ||||
|         assert value is not None | ||||
|         sys.last_type = type | ||||
|         sys.last_value = value | ||||
|         sys.last_traceback = tb | ||||
|  | @ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): | |||
|         lines = traceback.format_exception_only(type, value) | ||||
|         self.write("".join(lines)) | ||||
| 
 | ||||
|     def showtraceback(self): | ||||
|     def showtraceback(self) -> None: | ||||
|         """Display the exception that just occurred. | ||||
| 
 | ||||
|         Overrides the base implementation, ignoring sys.excepthook. We always want | ||||
|  | @ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter): | |||
|         """ | ||||
|         sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() | ||||
|         sys.last_traceback = last_tb | ||||
|         assert last_tb is not None | ||||
| 
 | ||||
|         try: | ||||
|             # We remove the first stack item because it is our own code. | ||||
|             lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) | ||||
|             self.write("".join(lines)) | ||||
|         finally: | ||||
|             last_tb = ei = None | ||||
|             # On the line below, last_tb and ei appear to be dead. | ||||
|             # It's unclear whether there is a reason behind this line. | ||||
|             # It conceivably could be because an exception raised in this block | ||||
|             # will keep the local frame (containing these local variables) around. | ||||
|             # This was adapted taken from CPython's Lib/code.py; see here: | ||||
|             # https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150 | ||||
|             last_tb = ei = None  # type: ignore | ||||
| 
 | ||||
|     def displayhook(self, obj): | ||||
|     def displayhook(self, obj: Any) -> None: | ||||
|         """ | ||||
|         We override the displayhook so that we automatically convert coroutines | ||||
|         into Deferreds. (Our superclass' displayhook will take care of the rest, | ||||
|  |  | |||
|  | @ -24,7 +24,7 @@ from twisted.python.failure import Failure | |||
| _already_patched = False | ||||
| 
 | ||||
| 
 | ||||
| def do_patch(): | ||||
| def do_patch() -> None: | ||||
|     """ | ||||
|     Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit | ||||
|     """ | ||||
|  | @ -107,7 +107,7 @@ def do_patch(): | |||
|     _already_patched = True | ||||
| 
 | ||||
| 
 | ||||
| def _check_yield_points(f: Callable, changes: List[str]): | ||||
| def _check_yield_points(f: Callable, changes: List[str]) -> Callable: | ||||
|     """Wraps a generator that is about to be passed to defer.inlineCallbacks | ||||
|     checking that after every yield the log contexts are correct. | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,33 +15,36 @@ | |||
| import collections | ||||
| import contextlib | ||||
| import logging | ||||
| import typing | ||||
| from typing import Any, DefaultDict, Iterator, List, Set | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import LimitExceededError | ||||
| from synapse.config.ratelimiting import FederationRateLimitConfig | ||||
| from synapse.logging.context import ( | ||||
|     PreserveLoggingContext, | ||||
|     make_deferred_yieldable, | ||||
|     run_in_background, | ||||
| ) | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
| if typing.TYPE_CHECKING: | ||||
|     from contextlib import _GeneratorContextManager | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class FederationRateLimiter: | ||||
|     def __init__(self, clock, config): | ||||
|         """ | ||||
|         Args: | ||||
|             clock (Clock) | ||||
|             config (FederationRateLimitConfig) | ||||
|         """ | ||||
| 
 | ||||
|         def new_limiter(): | ||||
|     def __init__(self, clock: Clock, config: FederationRateLimitConfig): | ||||
|         def new_limiter() -> "_PerHostRatelimiter": | ||||
|             return _PerHostRatelimiter(clock=clock, config=config) | ||||
| 
 | ||||
|         self.ratelimiters = collections.defaultdict(new_limiter) | ||||
|         self.ratelimiters: DefaultDict[ | ||||
|             str, "_PerHostRatelimiter" | ||||
|         ] = collections.defaultdict(new_limiter) | ||||
| 
 | ||||
|     def ratelimit(self, host): | ||||
|     def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]": | ||||
|         """Used to ratelimit an incoming request from a given host | ||||
| 
 | ||||
|         Example usage: | ||||
|  | @ -60,11 +63,11 @@ class FederationRateLimiter: | |||
| 
 | ||||
| 
 | ||||
| class _PerHostRatelimiter: | ||||
|     def __init__(self, clock, config): | ||||
|     def __init__(self, clock: Clock, config: FederationRateLimitConfig): | ||||
|         """ | ||||
|         Args: | ||||
|             clock (Clock) | ||||
|             config (FederationRateLimitConfig) | ||||
|             clock | ||||
|             config | ||||
|         """ | ||||
|         self.clock = clock | ||||
| 
 | ||||
|  | @ -75,21 +78,23 @@ class _PerHostRatelimiter: | |||
|         self.concurrent_requests = config.concurrent | ||||
| 
 | ||||
|         # request_id objects for requests which have been slept | ||||
|         self.sleeping_requests = set() | ||||
|         self.sleeping_requests: Set[object] = set() | ||||
| 
 | ||||
|         # map from request_id object to Deferred for requests which are ready | ||||
|         # for processing but have been queued | ||||
|         self.ready_request_queue = collections.OrderedDict() | ||||
|         self.ready_request_queue: collections.OrderedDict[ | ||||
|             object, defer.Deferred[None] | ||||
|         ] = collections.OrderedDict() | ||||
| 
 | ||||
|         # request id objects for requests which are in progress | ||||
|         self.current_processing = set() | ||||
|         self.current_processing: Set[object] = set() | ||||
| 
 | ||||
|         # times at which we have recently (within the last window_size ms) | ||||
|         # received requests. | ||||
|         self.request_times = [] | ||||
|         self.request_times: List[int] = [] | ||||
| 
 | ||||
|     @contextlib.contextmanager | ||||
|     def ratelimit(self): | ||||
|     def ratelimit(self) -> "Iterator[defer.Deferred[None]]": | ||||
|         # `contextlib.contextmanager` takes a generator and turns it into a | ||||
|         # context manager. The generator should only yield once with a value | ||||
|         # to be returned by manager. | ||||
|  | @ -102,7 +107,7 @@ class _PerHostRatelimiter: | |||
|         finally: | ||||
|             self._on_exit(request_id) | ||||
| 
 | ||||
|     def _on_enter(self, request_id): | ||||
|     def _on_enter(self, request_id: object) -> "defer.Deferred[None]": | ||||
|         time_now = self.clock.time_msec() | ||||
| 
 | ||||
|         # remove any entries from request_times which aren't within the window | ||||
|  | @ -120,9 +125,9 @@ class _PerHostRatelimiter: | |||
| 
 | ||||
|         self.request_times.append(time_now) | ||||
| 
 | ||||
|         def queue_request(): | ||||
|         def queue_request() -> "defer.Deferred[None]": | ||||
|             if len(self.current_processing) >= self.concurrent_requests: | ||||
|                 queue_defer = defer.Deferred() | ||||
|                 queue_defer: defer.Deferred[None] = defer.Deferred() | ||||
|                 self.ready_request_queue[request_id] = queue_defer | ||||
|                 logger.info( | ||||
|                     "Ratelimiter: queueing request (queue now %i items)", | ||||
|  | @ -145,7 +150,7 @@ class _PerHostRatelimiter: | |||
| 
 | ||||
|             self.sleeping_requests.add(request_id) | ||||
| 
 | ||||
|             def on_wait_finished(_): | ||||
|             def on_wait_finished(_: Any) -> "defer.Deferred[None]": | ||||
|                 logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) | ||||
|                 self.sleeping_requests.discard(request_id) | ||||
|                 queue_defer = queue_request() | ||||
|  | @ -155,19 +160,19 @@ class _PerHostRatelimiter: | |||
|         else: | ||||
|             ret_defer = queue_request() | ||||
| 
 | ||||
|         def on_start(r): | ||||
|         def on_start(r: object) -> object: | ||||
|             logger.debug("Ratelimit [%s]: Processing req", id(request_id)) | ||||
|             self.current_processing.add(request_id) | ||||
|             return r | ||||
| 
 | ||||
|         def on_err(r): | ||||
|         def on_err(r: object) -> object: | ||||
|             # XXX: why is this necessary? this is called before we start | ||||
|             # processing the request so why would the request be in | ||||
|             # current_processing? | ||||
|             self.current_processing.discard(request_id) | ||||
|             return r | ||||
| 
 | ||||
|         def on_both(r): | ||||
|         def on_both(r: object) -> object: | ||||
|             # Ensure that we've properly cleaned up. | ||||
|             self.sleeping_requests.discard(request_id) | ||||
|             self.ready_request_queue.pop(request_id, None) | ||||
|  | @ -177,7 +182,7 @@ class _PerHostRatelimiter: | |||
|         ret_defer.addBoth(on_both) | ||||
|         return make_deferred_yieldable(ret_defer) | ||||
| 
 | ||||
|     def _on_exit(self, request_id): | ||||
|     def _on_exit(self, request_id: object) -> None: | ||||
|         logger.debug("Ratelimit [%s]: Processed req", id(request_id)) | ||||
|         self.current_processing.discard(request_id) | ||||
|         try: | ||||
|  |  | |||
|  | @ -13,9 +13,13 @@ | |||
| # limitations under the License. | ||||
| import logging | ||||
| import random | ||||
| from types import TracebackType | ||||
| from typing import Any, Optional, Type | ||||
| 
 | ||||
| import synapse.logging.context | ||||
| from synapse.api.errors import CodeMessageException | ||||
| from synapse.storage import DataStore | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62 | |||
| 
 | ||||
| 
 | ||||
| class NotRetryingDestination(Exception): | ||||
|     def __init__(self, retry_last_ts, retry_interval, destination): | ||||
|     def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): | ||||
|         """Raised by the limiter (and federation client) to indicate that we are | ||||
|         are deliberately not attempting to contact a given server. | ||||
| 
 | ||||
|         Args: | ||||
|             retry_last_ts (int): the unix ts in milliseconds of our last attempt | ||||
|             retry_last_ts: the unix ts in milliseconds of our last attempt | ||||
|                 to contact the server.  0 indicates that the last attempt was | ||||
|                 successful or that we've never actually attempted to connect. | ||||
|             retry_interval (int): the time in milliseconds to wait until the next | ||||
|             retry_interval: the time in milliseconds to wait until the next | ||||
|                 attempt. | ||||
|             destination (str): the domain in question | ||||
|             destination: the domain in question | ||||
|         """ | ||||
| 
 | ||||
|         msg = "Not retrying server %s." % (destination,) | ||||
|  | @ -51,7 +55,13 @@ class NotRetryingDestination(Exception): | |||
|         self.destination = destination | ||||
| 
 | ||||
| 
 | ||||
| async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): | ||||
| async def get_retry_limiter( | ||||
|     destination: str, | ||||
|     clock: Clock, | ||||
|     store: DataStore, | ||||
|     ignore_backoff: bool = False, | ||||
|     **kwargs: Any, | ||||
| ) -> "RetryDestinationLimiter": | ||||
|     """For a given destination check if we have previously failed to | ||||
|     send a request there and are waiting before retrying the destination. | ||||
|     If we are not ready to retry the destination, this will raise a | ||||
|  | @ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k | |||
|     CodeMessageException with code < 500) | ||||
| 
 | ||||
|     Args: | ||||
|         destination (str): name of homeserver | ||||
|         clock (synapse.util.clock): timing source | ||||
|         store (synapse.storage.transactions.TransactionStore): datastore | ||||
|         ignore_backoff (bool): true to ignore the historical backoff data and | ||||
|         destination: name of homeserver | ||||
|         clock: timing source | ||||
|         store: datastore | ||||
|         ignore_backoff: true to ignore the historical backoff data and | ||||
|             try the request anyway. We will still reset the retry_interval on success. | ||||
| 
 | ||||
|     Example usage: | ||||
|  | @ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k | |||
| class RetryDestinationLimiter: | ||||
|     def __init__( | ||||
|         self, | ||||
|         destination, | ||||
|         clock, | ||||
|         store, | ||||
|         failure_ts, | ||||
|         retry_interval, | ||||
|         backoff_on_404=False, | ||||
|         backoff_on_failure=True, | ||||
|         destination: str, | ||||
|         clock: Clock, | ||||
|         store: DataStore, | ||||
|         failure_ts: Optional[int], | ||||
|         retry_interval: int, | ||||
|         backoff_on_404: bool = False, | ||||
|         backoff_on_failure: bool = True, | ||||
|     ): | ||||
|         """Marks the destination as "down" if an exception is thrown in the | ||||
|         context, except for CodeMessageException with code < 500. | ||||
|  | @ -128,17 +138,17 @@ class RetryDestinationLimiter: | |||
|         If no exception is raised, marks the destination as "up". | ||||
| 
 | ||||
|         Args: | ||||
|             destination (str) | ||||
|             clock (Clock) | ||||
|             store (DataStore) | ||||
|             failure_ts (int|None): when this destination started failing (in ms since | ||||
|             destination | ||||
|             clock | ||||
|             store | ||||
|             failure_ts: when this destination started failing (in ms since | ||||
|                 the epoch), or zero if the last request was successful | ||||
|             retry_interval (int): The next retry interval taken from the | ||||
|             retry_interval: The next retry interval taken from the | ||||
|                 database in milliseconds, or zero if the last request was | ||||
|                 successful. | ||||
|             backoff_on_404 (bool): Back off if we get a 404 | ||||
|             backoff_on_404: Back off if we get a 404 | ||||
| 
 | ||||
|             backoff_on_failure (bool): set to False if we should not increase the | ||||
|             backoff_on_failure: set to False if we should not increase the | ||||
|                 retry interval on a failure. | ||||
|         """ | ||||
|         self.clock = clock | ||||
|  | @ -150,10 +160,15 @@ class RetryDestinationLimiter: | |||
|         self.backoff_on_404 = backoff_on_404 | ||||
|         self.backoff_on_failure = backoff_on_failure | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|     def __enter__(self) -> None: | ||||
|         pass | ||||
| 
 | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|     def __exit__( | ||||
|         self, | ||||
|         exc_type: Optional[Type[BaseException]], | ||||
|         exc_val: Optional[BaseException], | ||||
|         exc_tb: Optional[TracebackType], | ||||
|     ) -> None: | ||||
|         valid_err_code = False | ||||
|         if exc_type is None: | ||||
|             valid_err_code = True | ||||
|  | @ -161,7 +176,7 @@ class RetryDestinationLimiter: | |||
|             # avoid treating exceptions which don't derive from Exception as | ||||
|             # failures; this is mostly so as not to catch defer._DefGen. | ||||
|             valid_err_code = True | ||||
|         elif issubclass(exc_type, CodeMessageException): | ||||
|         elif isinstance(exc_val, CodeMessageException): | ||||
|             # Some error codes are perfectly fine for some APIs, whereas other | ||||
|             # APIs may expect to never received e.g. a 404. It's important to | ||||
|             # handle 404 as some remote servers will return a 404 when the HS | ||||
|  | @ -216,7 +231,7 @@ class RetryDestinationLimiter: | |||
|             if self.failure_ts is None: | ||||
|                 self.failure_ts = retry_last_ts | ||||
| 
 | ||||
|         async def store_retry_timings(): | ||||
|         async def store_retry_timings() -> None: | ||||
|             try: | ||||
|                 await self.store.set_destination_retry_timings( | ||||
|                     self.destination, | ||||
|  |  | |||
|  | @ -18,7 +18,7 @@ import resource | |||
| logger = logging.getLogger("synapse.app.homeserver") | ||||
| 
 | ||||
| 
 | ||||
| def change_resource_limit(soft_file_no): | ||||
| def change_resource_limit(soft_file_no: int) -> None: | ||||
|     try: | ||||
|         soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) | ||||
| 
 | ||||
|  |  | |||
|  | @ -16,7 +16,7 @@ | |||
| 
 | ||||
| import time | ||||
| import urllib.parse | ||||
| from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union | ||||
| from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union | ||||
| 
 | ||||
| import jinja2 | ||||
| 
 | ||||
|  | @ -25,9 +25,9 @@ if TYPE_CHECKING: | |||
| 
 | ||||
| 
 | ||||
| def build_jinja_env( | ||||
|     template_search_directories: Iterable[str], | ||||
|     template_search_directories: Sequence[str], | ||||
|     config: "HomeServerConfig", | ||||
|     autoescape: Union[bool, Callable[[str], bool], None] = None, | ||||
|     autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None, | ||||
| ) -> jinja2.Environment: | ||||
|     """Set up a Jinja2 environment to load templates from the given search path | ||||
| 
 | ||||
|  | @ -110,5 +110,5 @@ def _create_mxc_to_http_filter( | |||
|     return mxc_to_http_filter | ||||
| 
 | ||||
| 
 | ||||
| def _format_ts_filter(value: int, format: str): | ||||
| def _format_ts_filter(value: int, format: str) -> str: | ||||
|     return time.strftime(format, time.localtime(value / 1000)) | ||||
|  |  | |||
|  | @ -14,6 +14,10 @@ | |||
| 
 | ||||
| import logging | ||||
| import re | ||||
| import typing | ||||
| 
 | ||||
| if typing.TYPE_CHECKING: | ||||
|     from synapse.server import HomeServer | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -28,13 +32,13 @@ logger = logging.getLogger(__name__) | |||
| MAX_EMAIL_ADDRESS_LENGTH = 500 | ||||
| 
 | ||||
| 
 | ||||
| def check_3pid_allowed(hs, medium, address): | ||||
| def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: | ||||
|     """Checks whether a given format of 3PID is allowed to be used on this HS | ||||
| 
 | ||||
|     Args: | ||||
|         hs (synapse.server.HomeServer): server | ||||
|         medium (str): 3pid medium - e.g. email, msisdn | ||||
|         address (str): address within that medium (e.g. "wotan@matrix.org") | ||||
|         hs: server | ||||
|         medium: 3pid medium - e.g. email, msisdn | ||||
|         address: address within that medium (e.g. "wotan@matrix.org") | ||||
|             msisdns need to first have been canonicalised | ||||
|     Returns: | ||||
|         bool: whether the 3PID medium/address is allowed to be added to this HS | ||||
|  |  | |||
|  | @ -19,7 +19,7 @@ import subprocess | |||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def get_version_string(module): | ||||
| def get_version_string(module) -> str: | ||||
|     """Given a module calculate a git-aware version string for it. | ||||
| 
 | ||||
|     If called on a module not in a git checkout will return `__verison__`. | ||||
|  |  | |||
|  | @ -11,38 +11,41 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from typing import Generic, List, TypeVar | ||||
| 
 | ||||
| T = TypeVar("T") | ||||
| 
 | ||||
| 
 | ||||
| class _Entry: | ||||
| class _Entry(Generic[T]): | ||||
|     __slots__ = ["end_key", "queue"] | ||||
| 
 | ||||
|     def __init__(self, end_key): | ||||
|         self.end_key = end_key | ||||
|         self.queue = [] | ||||
|     def __init__(self, end_key: int) -> None: | ||||
|         self.end_key: int = end_key | ||||
|         self.queue: List[T] = [] | ||||
| 
 | ||||
| 
 | ||||
| class WheelTimer: | ||||
| class WheelTimer(Generic[T]): | ||||
|     """Stores arbitrary objects that will be returned after their timers have | ||||
|     expired. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, bucket_size=5000): | ||||
|     def __init__(self, bucket_size: int = 5000) -> None: | ||||
|         """ | ||||
|         Args: | ||||
|             bucket_size (int): Size of buckets in ms. Corresponds roughly to the | ||||
|             bucket_size: Size of buckets in ms. Corresponds roughly to the | ||||
|                 accuracy of the timer. | ||||
|         """ | ||||
|         self.bucket_size = bucket_size | ||||
|         self.entries = [] | ||||
|         self.current_tick = 0 | ||||
|         self.bucket_size: int = bucket_size | ||||
|         self.entries: List[_Entry[T]] = [] | ||||
|         self.current_tick: int = 0 | ||||
| 
 | ||||
|     def insert(self, now, obj, then): | ||||
|     def insert(self, now: int, obj: T, then: int) -> None: | ||||
|         """Inserts object into timer. | ||||
| 
 | ||||
|         Args: | ||||
|             now (int): Current time in msec | ||||
|             obj (object): Object to be inserted | ||||
|             then (int): When to return the object strictly after. | ||||
|             now: Current time in msec | ||||
|             obj: Object to be inserted | ||||
|             then: When to return the object strictly after. | ||||
|         """ | ||||
|         then_key = int(then / self.bucket_size) + 1 | ||||
| 
 | ||||
|  | @ -70,7 +73,7 @@ class WheelTimer: | |||
| 
 | ||||
|         self.entries[-1].queue.append(obj) | ||||
| 
 | ||||
|     def fetch(self, now): | ||||
|     def fetch(self, now: int) -> List[T]: | ||||
|         """Fetch any objects that have timed out | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -87,5 +90,5 @@ class WheelTimer: | |||
| 
 | ||||
|         return ret | ||||
| 
 | ||||
|     def __len__(self): | ||||
|     def __len__(self) -> int: | ||||
|         return sum(len(entry.queue) for entry in self.entries) | ||||
|  |  | |||
|  | @ -734,9 +734,9 @@ class TestTransportLayerServer(JsonResource): | |||
|             FederationRateLimitConfig( | ||||
|                 window_size=1, | ||||
|                 sleep_limit=1, | ||||
|                 sleep_msec=1, | ||||
|                 sleep_delay=1, | ||||
|                 reject_limit=1000, | ||||
|                 concurrent_requests=1000, | ||||
|                 concurrent=1000, | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 reivilibre
						reivilibre