Use ParamSpec in a few places (#12667)

This commit is contained in:
David Robertson 2022-05-09 11:27:39 +01:00 committed by GitHub
parent c5969b346d
commit fa0eab9c8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 148 additions and 68 deletions

View file

@ -12,7 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
from twisted.internet import defer
@ -75,7 +87,11 @@ class Distributor:
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
class Signal:
P = ParamSpec("P")
R = TypeVar("R")
class Signal(Generic[P]):
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@ -87,16 +103,16 @@ class Signal:
def __init__(self, name: str):
self.name: str = name
self.observers: List[Callable] = []
self.observers: List[Callable[P, Any]] = []
def observe(self, observer: Callable) -> None:
def observe(self, observer: Callable[P, Any]) -> None:
"""Adds a new callable to the observer list which will be invoked by
the 'fire' method.
Each observer callable may return a Deferred."""
self.observers.append(observer)
def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers.
@ -104,7 +120,7 @@ class Signal:
Returns a Deferred that will complete when all the observers have
completed."""
async def do(observer: Callable[..., Any]) -> Any:
async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
try:
return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
@ -114,6 +130,7 @@ class Signal:
observer,
e,
)
return None
deferreds = [run_in_background(do, o) for o in self.observers]