diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py index f850d12..28f3c78 100644 --- a/maubot/handlers/command.py +++ b/maubot/handlers/command.py @@ -14,10 +14,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, - Dict, Tuple) + Dict, Tuple, Set) from abc import ABC, abstractmethod import asyncio import functools +import inspect import re from mautrix.types import MessageType, EventType @@ -26,6 +27,7 @@ from ..matrix import MaubotMessageEvent from . import event PrefixType = Optional[Union[str, Callable[[], str]]] +AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool]] CommandHandlerFunc = NewType("CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]]) CommandHandlerDecorator = NewType("CommandHandlerDecorator", @@ -35,28 +37,40 @@ PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc]) +def _split_in_two(val: str, split_by: str) -> List[str]: + return val.split(split_by, 1) if split_by in val else [val, ""] + + class CommandHandler: def __init__(self, func: CommandHandlerFunc) -> None: self.__mb_func__: CommandHandlerFunc = func - self.__mb_subcommands__: Dict[str, CommandHandler] = {} + self.__mb_parent__: CommandHandler = None + self.__mb_subcommands__: List[CommandHandler] = [] self.__mb_arguments__: List[Argument] = [] self.__mb_help__: str = None - self.__mb_name__: str = None - self.__mb_prefix__: str = None + self.__mb_get_name__: Callable[[], str] = None + self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset self.__mb_require_subcommand__: bool = True self.__mb_arg_fallthrough__: bool = True self.__mb_event_handler__: bool = True self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE self.__class_instance: Any = None + @staticmethod + def __command_match_unset(self, val: str) -> str: + raise NotImplementedError("Hmm") + async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None, - _remaining_val: str = None) -> Any: - body = evt.content.body - has_prefix = _remaining_val or body.startswith(self.__mb_prefix__) - if evt.sender == evt.client.mxid or not has_prefix: + remaining_val: str = None) -> Any: + if evt.sender == evt.client.mxid: return + if remaining_val is None: + if not evt.content.body or evt.content.body[0] != "!": + return + command, remaining_val = _split_in_two(evt.content.body[1:], " ") + if not self.__mb_is_command_match__(self, command): + return call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {} - remaining_val = _remaining_val or body[len(self.__mb_prefix__) + 1:] if not self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0: ok, res = await self.__call_subcommand__(evt, call_args, remaining_val) @@ -80,14 +94,12 @@ class CommandHandler: async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str) -> Tuple[bool, Any]: - remaining_val = remaining_val.strip() - split = remaining_val.split(" ") if len(remaining_val) > 0 else [] - try: - subcommand = self.__mb_subcommands__[split[0]] - return True, await subcommand(evt, _existing_args=call_args, - _remaining_val=" ".join(split[1:])) - except (KeyError, IndexError): - return False, None + command, remaining_val = _split_in_two(remaining_val.strip(), " ") + for subcommand in self.__mb_subcommands__: + if subcommand.__mb_is_command_match__(subcommand.__class_instance, command): + return True, await subcommand(evt, _existing_args=call_args, + remaining_val=remaining_val) + return False, None async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str) -> Tuple[bool, str]: @@ -111,7 +123,7 @@ class CommandHandler: @property def __mb_full_help__(self) -> str: usage = self.__mb_usage_without_subcommands__ + "\n\n" - usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__.values()) + usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__) return usage @property @@ -126,6 +138,16 @@ class CommandHandler: def __mb_usage_subcommand__(self) -> str: return f" [...]" + @property + def __mb_name__(self) -> str: + return self.__mb_get_name__(self.__class_instance) + + @property + def __mb_prefix__(self) -> str: + if self.__mb_parent__: + return f"{self.__mb_parent__.__mb_prefix__} {self.__mb_name__}" + return f"!{self.__mb_name__}" + @property def __mb_usage_inline__(self) -> str: if not self.__mb_arg_fallthrough__: @@ -150,31 +172,53 @@ class CommandHandler: return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}" return self.__mb_usage_without_subcommands__ - def subcommand(self, name: PrefixType = None, help: str = None + def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, + required_subcommand: bool = True, arg_fallthrough: bool = True, ) -> CommandHandlerDecorator: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: if not isinstance(func, CommandHandler): func = CommandHandler(func) - func.__mb_name__ = name or func.__name__ - func.__mb_prefix__ = f"{self.__mb_prefix__} {func.__mb_name__}" - func.__mb_help__ = help + new(name, help=help, aliases=aliases, require_subcommand=required_subcommand, + arg_fallthrough=arg_fallthrough)(func) + func.__mb_parent__ = self func.__mb_event_handler__ = False - self.__mb_subcommands__[func.__mb_name__] = func + self.__mb_subcommands__.append(func) return func return decorator -def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE, - require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator: +def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, + event_type: EventType = EventType.ROOM_MESSAGE, require_subcommand: bool = True, + arg_fallthrough: bool = True) -> CommandHandlerDecorator: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: if not isinstance(func, CommandHandler): func = CommandHandler(func) func.__mb_help__ = help - func.__mb_name__ = name or func.__name__ + if name: + if callable(name): + if len(inspect.getfullargspec(name).args) == 0: + func.__mb_get_name__ = lambda self: name() + else: + func.__mb_get_name__ = name + else: + func.__mb_get_name__ = lambda self: name + else: + func.__mb_get_name__ = lambda self: func.__name__ + if callable(aliases): + if len(inspect.getfullargspec(aliases).args) == 1: + func.__mb_is_command_match__ = lambda self, val: aliases(val) + else: + func.__mb_is_command_match__ = aliases + elif isinstance(aliases, (list, set, tuple)): + func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_name__ + or val in aliases) + else: + func.__mb_is_command_match__ = lambda self, val: val == func.__mb_name__ + # Decorators are executed last to first, so we reverse the argument list. + func.__mb_arguments__.reverse() func.__mb_require_subcommand__ = require_subcommand func.__mb_arg_fallthrough__ = arg_fallthrough - func.__mb_prefix__ = f"!{func.__mb_name__}" func.__mb_event_type__ = event_type return func