diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py
index f850d12..28f3c78 100644
--- a/maubot/handlers/command.py
+++ b/maubot/handlers/command.py
@@ -14,10 +14,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
- Dict, Tuple)
+ Dict, Tuple, Set)
from abc import ABC, abstractmethod
import asyncio
import functools
+import inspect
import re
from mautrix.types import MessageType, EventType
@@ -26,6 +27,7 @@ from ..matrix import MaubotMessageEvent
from . import event
PrefixType = Optional[Union[str, Callable[[], str]]]
+AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool]]
CommandHandlerFunc = NewType("CommandHandlerFunc",
Callable[[MaubotMessageEvent, Any], Awaitable[Any]])
CommandHandlerDecorator = NewType("CommandHandlerDecorator",
@@ -35,28 +37,40 @@ PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
Callable[[CommandHandlerFunc], CommandHandlerFunc])
+def _split_in_two(val: str, split_by: str) -> List[str]:
+ return val.split(split_by, 1) if split_by in val else [val, ""]
+
+
class CommandHandler:
def __init__(self, func: CommandHandlerFunc) -> None:
self.__mb_func__: CommandHandlerFunc = func
- self.__mb_subcommands__: Dict[str, CommandHandler] = {}
+ self.__mb_parent__: CommandHandler = None
+ self.__mb_subcommands__: List[CommandHandler] = []
self.__mb_arguments__: List[Argument] = []
self.__mb_help__: str = None
- self.__mb_name__: str = None
- self.__mb_prefix__: str = None
+ self.__mb_get_name__: Callable[[], str] = None
+ self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset
self.__mb_require_subcommand__: bool = True
self.__mb_arg_fallthrough__: bool = True
self.__mb_event_handler__: bool = True
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__class_instance: Any = None
+ @staticmethod
+ def __command_match_unset(self, val: str) -> str:
+ raise NotImplementedError("Hmm")
+
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
- _remaining_val: str = None) -> Any:
- body = evt.content.body
- has_prefix = _remaining_val or body.startswith(self.__mb_prefix__)
- if evt.sender == evt.client.mxid or not has_prefix:
+ remaining_val: str = None) -> Any:
+ if evt.sender == evt.client.mxid:
return
+ if remaining_val is None:
+ if not evt.content.body or evt.content.body[0] != "!":
+ return
+ command, remaining_val = _split_in_two(evt.content.body[1:], " ")
+ if not self.__mb_is_command_match__(self, command):
+ return
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
- remaining_val = _remaining_val or body[len(self.__mb_prefix__) + 1:]
if not self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
@@ -80,14 +94,12 @@ class CommandHandler:
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, Any]:
- remaining_val = remaining_val.strip()
- split = remaining_val.split(" ") if len(remaining_val) > 0 else []
- try:
- subcommand = self.__mb_subcommands__[split[0]]
- return True, await subcommand(evt, _existing_args=call_args,
- _remaining_val=" ".join(split[1:]))
- except (KeyError, IndexError):
- return False, None
+ command, remaining_val = _split_in_two(remaining_val.strip(), " ")
+ for subcommand in self.__mb_subcommands__:
+ if subcommand.__mb_is_command_match__(subcommand.__class_instance, command):
+ return True, await subcommand(evt, _existing_args=call_args,
+ remaining_val=remaining_val)
+ return False, None
async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, str]:
@@ -111,7 +123,7 @@ class CommandHandler:
@property
def __mb_full_help__(self) -> str:
usage = self.__mb_usage_without_subcommands__ + "\n\n"
- usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__.values())
+ usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__)
return usage
@property
@@ -126,6 +138,16 @@ class CommandHandler:
def __mb_usage_subcommand__(self) -> str:
return f" [...]"
+ @property
+ def __mb_name__(self) -> str:
+ return self.__mb_get_name__(self.__class_instance)
+
+ @property
+ def __mb_prefix__(self) -> str:
+ if self.__mb_parent__:
+ return f"{self.__mb_parent__.__mb_prefix__} {self.__mb_name__}"
+ return f"!{self.__mb_name__}"
+
@property
def __mb_usage_inline__(self) -> str:
if not self.__mb_arg_fallthrough__:
@@ -150,31 +172,53 @@ class CommandHandler:
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
return self.__mb_usage_without_subcommands__
- def subcommand(self, name: PrefixType = None, help: str = None
+ def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
+ required_subcommand: bool = True, arg_fallthrough: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
- func.__mb_name__ = name or func.__name__
- func.__mb_prefix__ = f"{self.__mb_prefix__} {func.__mb_name__}"
- func.__mb_help__ = help
+ new(name, help=help, aliases=aliases, require_subcommand=required_subcommand,
+ arg_fallthrough=arg_fallthrough)(func)
+ func.__mb_parent__ = self
func.__mb_event_handler__ = False
- self.__mb_subcommands__[func.__mb_name__] = func
+ self.__mb_subcommands__.append(func)
return func
return decorator
-def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
- require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator:
+def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
+ event_type: EventType = EventType.ROOM_MESSAGE, require_subcommand: bool = True,
+ arg_fallthrough: bool = True) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
func.__mb_help__ = help
- func.__mb_name__ = name or func.__name__
+ if name:
+ if callable(name):
+ if len(inspect.getfullargspec(name).args) == 0:
+ func.__mb_get_name__ = lambda self: name()
+ else:
+ func.__mb_get_name__ = name
+ else:
+ func.__mb_get_name__ = lambda self: name
+ else:
+ func.__mb_get_name__ = lambda self: func.__name__
+ if callable(aliases):
+ if len(inspect.getfullargspec(aliases).args) == 1:
+ func.__mb_is_command_match__ = lambda self, val: aliases(val)
+ else:
+ func.__mb_is_command_match__ = aliases
+ elif isinstance(aliases, (list, set, tuple)):
+ func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_name__
+ or val in aliases)
+ else:
+ func.__mb_is_command_match__ = lambda self, val: val == func.__mb_name__
+ # Decorators are executed last to first, so we reverse the argument list.
+ func.__mb_arguments__.reverse()
func.__mb_require_subcommand__ = require_subcommand
func.__mb_arg_fallthrough__ = arg_fallthrough
- func.__mb_prefix__ = f"!{func.__mb_name__}"
func.__mb_event_type__ = event_type
return func