diff --git a/changelog.d/12769.misc b/changelog.d/12769.misc new file mode 100644 index 000000000..27bd53abe --- /dev/null +++ b/changelog.d/12769.misc @@ -0,0 +1 @@ +Tweak the mypy plugin so that `@cached` can accept `on_invalidate=None`. diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index c77586521..d08517a95 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -21,7 +21,7 @@ from typing import Callable, Optional, Type from mypy.nodes import ARG_NAMED_OPT from mypy.plugin import MethodSigContext, Plugin from mypy.typeops import bind_self -from mypy.types import CallableType, NoneType +from mypy.types import CallableType, NoneType, UnionType class SynapsePlugin(Plugin): @@ -72,13 +72,20 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: # Third, we add an optional "on_invalidate" argument. # - # This is a callable which accepts no input and returns nothing. - calltyp = CallableType( - arg_types=[], - arg_kinds=[], - arg_names=[], - ret_type=NoneType(), - fallback=ctx.api.named_generic_type("builtins.function", []), + # This is a either + # - a callable which accepts no input and returns nothing, or + # - None. + calltyp = UnionType( + [ + NoneType(), + CallableType( + arg_types=[], + arg_kinds=[], + arg_names=[], + ret_type=NoneType(), + fallback=ctx.api.named_generic_type("builtins.function", []), + ), + ] ) arg_types.append(calltyp) @@ -95,7 +102,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: def plugin(version: str) -> Type[SynapsePlugin]: - # This is the entry point of the plugin, and let's us deal with the fact + # This is the entry point of the plugin, and lets us deal with the fact # that the mypy plugin interface is *not* stable by looking at the version # string. # diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 608d40dfa..cc528fcf2 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,6 +15,7 @@ import logging from typing import ( TYPE_CHECKING, + Callable, Collection, Dict, FrozenSet, @@ -634,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) async def get_rooms_for_user( - self, user_id: str, on_invalidate=None + self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None ) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to.