mirror of
https://github.com/maubot/maubot.git
synced 2024-10-01 01:06:10 -04:00
Make new command handling system fully work
This commit is contained in:
parent
5ff5eae3c6
commit
0cf06f9f6b
@ -13,7 +13,10 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, Dict
|
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
|
||||||
|
Dict, Tuple)
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -41,42 +44,65 @@ class CommandHandler:
|
|||||||
self.__mb_name__: str = None
|
self.__mb_name__: str = None
|
||||||
self.__mb_prefix__: str = None
|
self.__mb_prefix__: str = None
|
||||||
self.__mb_require_subcommand__: bool = True
|
self.__mb_require_subcommand__: bool = True
|
||||||
|
self.__mb_arg_fallthrough__: bool = True
|
||||||
self.__mb_event_handler__: bool = True
|
self.__mb_event_handler__: bool = True
|
||||||
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
|
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
|
||||||
self.__class_instance: Any = None
|
self.__class_instance: Any = None
|
||||||
|
|
||||||
async def __call__(self, evt: MaubotMessageEvent, *,
|
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
|
||||||
_existing_args: Dict[str, Any] = None) -> Any:
|
_remaining_val: str = None) -> Any:
|
||||||
body = evt.content.body
|
body = evt.content.body
|
||||||
if evt.sender == evt.client.mxid or not body.startswith(self.__mb_prefix__):
|
has_prefix = _remaining_val or body.startswith(self.__mb_prefix__)
|
||||||
|
if evt.sender == evt.client.mxid or not has_prefix:
|
||||||
return
|
return
|
||||||
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
|
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
|
||||||
remaining_val = body[len(self.__mb_prefix__) + 1:]
|
remaining_val = _remaining_val or body[len(self.__mb_prefix__) + 1:]
|
||||||
# TODO update remaining_val somehow
|
|
||||||
|
if not self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
|
||||||
|
ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
|
||||||
|
if ok:
|
||||||
|
return res
|
||||||
|
|
||||||
|
ok, remaining_val = await self.__parse_args__(evt, call_args, remaining_val)
|
||||||
|
if not ok:
|
||||||
|
return
|
||||||
|
elif self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
|
||||||
|
ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
|
||||||
|
if ok:
|
||||||
|
return res
|
||||||
|
elif self.__mb_require_subcommand__:
|
||||||
|
await evt.reply(self.__mb_full_help__)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.__class_instance:
|
||||||
|
return await self.__mb_func__(self.__class_instance, evt, **call_args)
|
||||||
|
return await self.__mb_func__(evt, **call_args)
|
||||||
|
|
||||||
|
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
|
||||||
|
remaining_val: str) -> Tuple[bool, Any]:
|
||||||
|
remaining_val = remaining_val.strip()
|
||||||
|
split = remaining_val.split(" ") if len(remaining_val) > 0 else []
|
||||||
|
try:
|
||||||
|
subcommand = self.__mb_subcommands__[split[0]]
|
||||||
|
return True, await subcommand(evt, _existing_args=call_args,
|
||||||
|
_remaining_val=" ".join(split[1:]))
|
||||||
|
except (KeyError, IndexError):
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
|
||||||
|
remaining_val: str) -> Tuple[bool, str]:
|
||||||
for arg in self.__mb_arguments__:
|
for arg in self.__mb_arguments__:
|
||||||
try:
|
try:
|
||||||
call_args[arg.name] = arg.match(remaining_val)
|
remaining_val, call_args[arg.name] = arg.match(remaining_val.strip())
|
||||||
if arg.required and not call_args[arg.name]:
|
if arg.required and not call_args[arg.name]:
|
||||||
raise ValueError("Argument required")
|
raise ValueError("Argument required")
|
||||||
except ArgumentSyntaxError as e:
|
except ArgumentSyntaxError as e:
|
||||||
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
|
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
|
||||||
return
|
return False, remaining_val
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
await evt.reply(self.__mb_usage__)
|
await evt.reply(self.__mb_usage__)
|
||||||
return
|
return False, remaining_val
|
||||||
|
return True, remaining_val
|
||||||
if len(self.__mb_subcommands__) > 0:
|
|
||||||
split = remaining_val.split(" ") if len(remaining_val) > 0 else []
|
|
||||||
try:
|
|
||||||
subcommand = self.__mb_subcommands__[split[0]]
|
|
||||||
return await subcommand(evt, _existing_args=call_args)
|
|
||||||
except (KeyError, IndexError):
|
|
||||||
if self.__mb_require_subcommand__:
|
|
||||||
await evt.reply(self.__mb_full_help__)
|
|
||||||
return
|
|
||||||
return (await self.__mb_func__(self.__class_instance, evt, **call_args)
|
|
||||||
if self.__class_instance
|
|
||||||
else await self.__mb_func__(evt, **call_args))
|
|
||||||
|
|
||||||
def __get__(self, instance, instancetype):
|
def __get__(self, instance, instancetype):
|
||||||
self.__class_instance = instance
|
self.__class_instance = instance
|
||||||
@ -84,20 +110,45 @@ class CommandHandler:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def __mb_full_help__(self) -> str:
|
def __mb_full_help__(self) -> str:
|
||||||
basic = self.__mb_usage__
|
usage = self.__mb_usage_without_subcommands__ + "\n\n"
|
||||||
usage = f"{basic} <subcommand> [...]\n\n"
|
usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__.values())
|
||||||
usage += "\n".join(f"* {cmd.__mb_name__} {cmd.__mb_usage_args__} - {cmd.__mb_help__}"
|
|
||||||
for cmd in self.__mb_subcommands__.values())
|
|
||||||
return usage
|
return usage
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def __mb_usage_args__(self) -> str:
|
def __mb_usage_args__(self) -> str:
|
||||||
return " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
|
arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
|
||||||
for arg in self.__mb_arguments__)
|
for arg in self.__mb_arguments__)
|
||||||
|
if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
|
||||||
|
arg_usage += " " + self.__mb_usage_subcommand__
|
||||||
|
return arg_usage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __mb_usage_subcommand__(self) -> str:
|
||||||
|
return f"<subcommand> [...]"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __mb_usage_inline__(self) -> str:
|
||||||
|
if not self.__mb_arg_fallthrough__:
|
||||||
|
return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
|
||||||
|
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}")
|
||||||
|
return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __mb_subcommands_list__(self) -> str:
|
||||||
|
return f"**Subcommands:** {', '.join(self.__mb_subcommands__.keys())}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __mb_usage_without_subcommands__(self) -> str:
|
||||||
|
if not self.__mb_arg_fallthrough__:
|
||||||
|
return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
|
||||||
|
f" _OR_ {self.__mb_usage_subcommand__}")
|
||||||
|
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def __mb_usage__(self) -> str:
|
def __mb_usage__(self) -> str:
|
||||||
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
|
if len(self.__mb_subcommands__) > 0:
|
||||||
|
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
|
||||||
|
return self.__mb_usage_without_subcommands__
|
||||||
|
|
||||||
def subcommand(self, name: PrefixType = None, help: str = None
|
def subcommand(self, name: PrefixType = None, help: str = None
|
||||||
) -> CommandHandlerDecorator:
|
) -> CommandHandlerDecorator:
|
||||||
@ -114,6 +165,22 @@ class CommandHandler:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
|
||||||
|
require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator:
|
||||||
|
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
|
||||||
|
if not isinstance(func, CommandHandler):
|
||||||
|
func = CommandHandler(func)
|
||||||
|
func.__mb_help__ = help
|
||||||
|
func.__mb_name__ = name or func.__name__
|
||||||
|
func.__mb_require_subcommand__ = require_subcommand
|
||||||
|
func.__mb_arg_fallthrough__ = arg_fallthrough
|
||||||
|
func.__mb_prefix__ = f"!{func.__mb_name__}"
|
||||||
|
func.__mb_event_type__ = event_type
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class ArgumentSyntaxError(ValueError):
|
class ArgumentSyntaxError(ValueError):
|
||||||
def __init__(self, message: str, show_usage: bool = True) -> None:
|
def __init__(self, message: str, show_usage: bool = True) -> None:
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
@ -121,36 +188,17 @@ class ArgumentSyntaxError(ValueError):
|
|||||||
self.show_usage = show_usage
|
self.show_usage = show_usage
|
||||||
|
|
||||||
|
|
||||||
class Argument:
|
class Argument(ABC):
|
||||||
def __init__(self, name: str, label: str = None, *, required: bool = False,
|
def __init__(self, name: str, label: str = None, *, required: bool = False,
|
||||||
matches: Optional[str] = None, parser: Optional[Callable[[str], Any]] = None,
|
|
||||||
pass_raw: bool = False) -> None:
|
pass_raw: bool = False) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.required = required
|
|
||||||
self.label = label or name
|
self.label = label or name
|
||||||
|
self.required = required
|
||||||
|
self.pass_raw = pass_raw
|
||||||
|
|
||||||
if not parser:
|
@abstractmethod
|
||||||
if matches:
|
def match(self, val: str) -> Tuple[str, Any]:
|
||||||
regex = re.compile(matches)
|
pass
|
||||||
|
|
||||||
def parser(val: str) -> Optional[Sequence[str]]:
|
|
||||||
match = regex.match(val)
|
|
||||||
return match.groups() if match else None
|
|
||||||
else:
|
|
||||||
def parser(val: str) -> str:
|
|
||||||
return val
|
|
||||||
|
|
||||||
if not pass_raw:
|
|
||||||
o_parser = parser
|
|
||||||
|
|
||||||
def parser(val: str) -> Any:
|
|
||||||
val = val.strip().split(" ")
|
|
||||||
return o_parser(val[0])
|
|
||||||
|
|
||||||
self.parser = parser
|
|
||||||
|
|
||||||
def match(self, val: str) -> Any:
|
|
||||||
return self.parser(val)
|
|
||||||
|
|
||||||
def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
|
def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
|
||||||
if not isinstance(func, CommandHandler):
|
if not isinstance(func, CommandHandler):
|
||||||
@ -159,49 +207,108 @@ class Argument:
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
|
class RegexArgument(Argument):
|
||||||
require_subcommand: bool = True) -> CommandHandlerDecorator:
|
def __init__(self, name: str, label: str = None, *, required: bool = False,
|
||||||
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
|
pass_raw: bool = False, matches: str = None) -> None:
|
||||||
if not isinstance(func, CommandHandler):
|
super().__init__(name, label, required=required, pass_raw=pass_raw)
|
||||||
func = CommandHandler(func)
|
matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
|
||||||
func.__mb_help__ = help
|
self.regex = re.compile(matches)
|
||||||
func.__mb_name__ = name or func.__name__
|
|
||||||
func.__mb_require_subcommand__ = require_subcommand
|
|
||||||
func.__mb_prefix__ = f"!{func.__mb_name__}"
|
|
||||||
func.__mb_event_type__ = event_type
|
|
||||||
return func
|
|
||||||
|
|
||||||
return decorator
|
def match(self, val: str) -> Tuple[str, Any]:
|
||||||
|
orig_val = val
|
||||||
|
if not self.pass_raw:
|
||||||
|
val = val.split(" ")[0]
|
||||||
|
match = self.regex.match(val)
|
||||||
|
if match:
|
||||||
|
return (orig_val[:match.pos] + orig_val[match.endpos:],
|
||||||
|
match.groups() or val[match.pos:match.endpos])
|
||||||
|
return orig_val, None
|
||||||
|
|
||||||
|
|
||||||
|
class CustomArgument(Argument):
|
||||||
|
def __init__(self, name: str, label: str = None, *, required: bool = False,
|
||||||
|
pass_raw: bool = False, matcher: Callable[[str], Any]) -> None:
|
||||||
|
super().__init__(name, label, required=required, pass_raw=pass_raw)
|
||||||
|
self.matcher = matcher
|
||||||
|
|
||||||
|
def match(self, val: str) -> Tuple[str, Any]:
|
||||||
|
if self.pass_raw:
|
||||||
|
return self.matcher(val)
|
||||||
|
orig_val = val
|
||||||
|
val = val.split(" ")[0]
|
||||||
|
res = self.matcher(val)
|
||||||
|
if res:
|
||||||
|
return orig_val[len(val):], res
|
||||||
|
return orig_val, None
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleArgument(Argument):
|
||||||
|
def match(self, val: str) -> Tuple[str, Any]:
|
||||||
|
if self.pass_raw:
|
||||||
|
return "", val
|
||||||
|
res = val.split(" ")[0]
|
||||||
|
return val[len(res):], res
|
||||||
|
|
||||||
|
|
||||||
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
|
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
|
||||||
parser: Optional[Callable[[str], Any]] = None) -> CommandHandlerDecorator:
|
parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False
|
||||||
return Argument(name, label, required=required, matches=matches, parser=parser)
|
) -> CommandHandlerDecorator:
|
||||||
|
if matches:
|
||||||
|
return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
|
||||||
|
elif parser:
|
||||||
|
return CustomArgument(name, label, required=required, matcher=parser, pass_raw=pass_raw)
|
||||||
|
else:
|
||||||
|
return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
|
||||||
|
|
||||||
|
|
||||||
def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
|
def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
|
||||||
field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body,
|
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
|
||||||
event_type: EventType = EventType.ROOM_MESSAGE) -> PassiveCommandHandlerDecorator:
|
event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False
|
||||||
|
) -> PassiveCommandHandlerDecorator:
|
||||||
if not isinstance(regex, Pattern):
|
if not isinstance(regex, Pattern):
|
||||||
regex = re.compile(regex)
|
regex = re.compile(regex)
|
||||||
|
|
||||||
def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
|
def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
|
||||||
|
combine = None
|
||||||
|
if hasattr(func, "__mb_passive_orig__"):
|
||||||
|
combine = func
|
||||||
|
func = func.__mb_passive_orig__
|
||||||
|
|
||||||
@event.on(event_type)
|
@event.on(event_type)
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def replacement(self, evt: MaubotMessageEvent) -> None:
|
async def replacement(self, evt: MaubotMessageEvent = None) -> None:
|
||||||
if isinstance(self, MaubotMessageEvent):
|
if not evt and isinstance(self, MaubotMessageEvent):
|
||||||
evt = self
|
evt = self
|
||||||
self = None
|
self = None
|
||||||
if evt.sender == evt.client.mxid:
|
if evt.sender == evt.client.mxid:
|
||||||
return
|
return
|
||||||
elif msgtypes and evt.content.msgtype not in msgtypes:
|
elif msgtypes and evt.content.msgtype not in msgtypes:
|
||||||
return
|
return
|
||||||
match = regex.match(field(evt))
|
data = field(evt)
|
||||||
if match:
|
if multiple:
|
||||||
if self:
|
val = [(data[match.pos:match.endpos], *match.groups())
|
||||||
await func(self, evt, *list(match.groups()))
|
for match in regex.finditer(data)]
|
||||||
else:
|
else:
|
||||||
await func(evt, *list(match.groups()))
|
match = regex.match(data)
|
||||||
|
if match:
|
||||||
|
val = (data[match.pos:match.endpos], *match.groups())
|
||||||
|
else:
|
||||||
|
val = None
|
||||||
|
if val:
|
||||||
|
if self:
|
||||||
|
await func(self, evt, val)
|
||||||
|
else:
|
||||||
|
await func(evt, val)
|
||||||
|
|
||||||
|
if combine:
|
||||||
|
orig_replacement = replacement
|
||||||
|
|
||||||
|
@event.on(event_type)
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def replacement(self, evt: MaubotMessageEvent = None) -> None:
|
||||||
|
await asyncio.gather(combine(self, evt), orig_replacement(self, evt))
|
||||||
|
|
||||||
|
replacement.__mb_passive_orig__ = func
|
||||||
|
|
||||||
return replacement
|
return replacement
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user