Make new command handling system fully work

This commit is contained in:
Tulir Asokan 2018-12-25 00:37:02 +02:00
parent 5ff5eae3c6
commit 0cf06f9f6b

View File

@ -13,7 +13,10 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, Dict
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Dict, Tuple)
from abc import ABC, abstractmethod
import asyncio
import functools
import re
@ -41,42 +44,65 @@ class CommandHandler:
self.__mb_name__: str = None
self.__mb_prefix__: str = None
self.__mb_require_subcommand__: bool = True
self.__mb_arg_fallthrough__: bool = True
self.__mb_event_handler__: bool = True
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__class_instance: Any = None
async def __call__(self, evt: MaubotMessageEvent, *,
_existing_args: Dict[str, Any] = None) -> Any:
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
_remaining_val: str = None) -> Any:
body = evt.content.body
if evt.sender == evt.client.mxid or not body.startswith(self.__mb_prefix__):
has_prefix = _remaining_val or body.startswith(self.__mb_prefix__)
if evt.sender == evt.client.mxid or not has_prefix:
return
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
remaining_val = body[len(self.__mb_prefix__) + 1:]
# TODO update remaining_val somehow
remaining_val = _remaining_val or body[len(self.__mb_prefix__) + 1:]
if not self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
if ok:
return res
ok, remaining_val = await self.__parse_args__(evt, call_args, remaining_val)
if not ok:
return
elif self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
if ok:
return res
elif self.__mb_require_subcommand__:
await evt.reply(self.__mb_full_help__)
return
if self.__class_instance:
return await self.__mb_func__(self.__class_instance, evt, **call_args)
return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, Any]:
remaining_val = remaining_val.strip()
split = remaining_val.split(" ") if len(remaining_val) > 0 else []
try:
subcommand = self.__mb_subcommands__[split[0]]
return True, await subcommand(evt, _existing_args=call_args,
_remaining_val=" ".join(split[1:]))
except (KeyError, IndexError):
return False, None
async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, str]:
for arg in self.__mb_arguments__:
try:
call_args[arg.name] = arg.match(remaining_val)
remaining_val, call_args[arg.name] = arg.match(remaining_val.strip())
if arg.required and not call_args[arg.name]:
raise ValueError("Argument required")
except ArgumentSyntaxError as e:
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
return
return False, remaining_val
except ValueError as e:
await evt.reply(self.__mb_usage__)
return
if len(self.__mb_subcommands__) > 0:
split = remaining_val.split(" ") if len(remaining_val) > 0 else []
try:
subcommand = self.__mb_subcommands__[split[0]]
return await subcommand(evt, _existing_args=call_args)
except (KeyError, IndexError):
if self.__mb_require_subcommand__:
await evt.reply(self.__mb_full_help__)
return
return (await self.__mb_func__(self.__class_instance, evt, **call_args)
if self.__class_instance
else await self.__mb_func__(evt, **call_args))
return False, remaining_val
return True, remaining_val
def __get__(self, instance, instancetype):
self.__class_instance = instance
@ -84,20 +110,45 @@ class CommandHandler:
@property
def __mb_full_help__(self) -> str:
basic = self.__mb_usage__
usage = f"{basic} <subcommand> [...]\n\n"
usage += "\n".join(f"* {cmd.__mb_name__} {cmd.__mb_usage_args__} - {cmd.__mb_help__}"
for cmd in self.__mb_subcommands__.values())
usage = self.__mb_usage_without_subcommands__ + "\n\n"
usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__.values())
return usage
@property
def __mb_usage_args__(self) -> str:
return " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
for arg in self.__mb_arguments__)
if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
arg_usage += " " + self.__mb_usage_subcommand__
return arg_usage
@property
def __mb_usage_subcommand__(self) -> str:
return f"<subcommand> [...]"
@property
def __mb_usage_inline__(self) -> str:
if not self.__mb_arg_fallthrough__:
return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}")
return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
@property
def __mb_subcommands_list__(self) -> str:
return f"**Subcommands:** {', '.join(self.__mb_subcommands__.keys())}"
@property
def __mb_usage_without_subcommands__(self) -> str:
if not self.__mb_arg_fallthrough__:
return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_usage_subcommand__}")
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
@property
def __mb_usage__(self) -> str:
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
if len(self.__mb_subcommands__) > 0:
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
return self.__mb_usage_without_subcommands__
def subcommand(self, name: PrefixType = None, help: str = None
) -> CommandHandlerDecorator:
@ -114,6 +165,22 @@ class CommandHandler:
return decorator
def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
func.__mb_help__ = help
func.__mb_name__ = name or func.__name__
func.__mb_require_subcommand__ = require_subcommand
func.__mb_arg_fallthrough__ = arg_fallthrough
func.__mb_prefix__ = f"!{func.__mb_name__}"
func.__mb_event_type__ = event_type
return func
return decorator
class ArgumentSyntaxError(ValueError):
def __init__(self, message: str, show_usage: bool = True) -> None:
super().__init__(message)
@ -121,36 +188,17 @@ class ArgumentSyntaxError(ValueError):
self.show_usage = show_usage
class Argument:
class Argument(ABC):
def __init__(self, name: str, label: str = None, *, required: bool = False,
matches: Optional[str] = None, parser: Optional[Callable[[str], Any]] = None,
pass_raw: bool = False) -> None:
self.name = name
self.required = required
self.label = label or name
self.required = required
self.pass_raw = pass_raw
if not parser:
if matches:
regex = re.compile(matches)
def parser(val: str) -> Optional[Sequence[str]]:
match = regex.match(val)
return match.groups() if match else None
else:
def parser(val: str) -> str:
return val
if not pass_raw:
o_parser = parser
def parser(val: str) -> Any:
val = val.strip().split(" ")
return o_parser(val[0])
self.parser = parser
def match(self, val: str) -> Any:
return self.parser(val)
@abstractmethod
def match(self, val: str) -> Tuple[str, Any]:
pass
def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
@ -159,49 +207,108 @@ class Argument:
return func
def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
require_subcommand: bool = True) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
func.__mb_help__ = help
func.__mb_name__ = name or func.__name__
func.__mb_require_subcommand__ = require_subcommand
func.__mb_prefix__ = f"!{func.__mb_name__}"
func.__mb_event_type__ = event_type
return func
class RegexArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matches: str = None) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
self.regex = re.compile(matches)
return decorator
def match(self, val: str) -> Tuple[str, Any]:
orig_val = val
if not self.pass_raw:
val = val.split(" ")[0]
match = self.regex.match(val)
if match:
return (orig_val[:match.pos] + orig_val[match.endpos:],
match.groups() or val[match.pos:match.endpos])
return orig_val, None
class CustomArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matcher: Callable[[str], Any]) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
self.matcher = matcher
def match(self, val: str) -> Tuple[str, Any]:
if self.pass_raw:
return self.matcher(val)
orig_val = val
val = val.split(" ")[0]
res = self.matcher(val)
if res:
return orig_val[len(val):], res
return orig_val, None
class SimpleArgument(Argument):
def match(self, val: str) -> Tuple[str, Any]:
if self.pass_raw:
return "", val
res = val.split(" ")[0]
return val[len(res):], res
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None) -> CommandHandlerDecorator:
return Argument(name, label, required=required, matches=matches, parser=parser)
parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False
) -> CommandHandlerDecorator:
if matches:
return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
elif parser:
return CustomArgument(name, label, required=required, matcher=parser, pass_raw=pass_raw)
else:
return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body,
event_type: EventType = EventType.ROOM_MESSAGE) -> PassiveCommandHandlerDecorator:
def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False
) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern):
regex = re.compile(regex)
def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
combine = None
if hasattr(func, "__mb_passive_orig__"):
combine = func
func = func.__mb_passive_orig__
@event.on(event_type)
@functools.wraps(func)
async def replacement(self, evt: MaubotMessageEvent) -> None:
if isinstance(self, MaubotMessageEvent):
async def replacement(self, evt: MaubotMessageEvent = None) -> None:
if not evt and isinstance(self, MaubotMessageEvent):
evt = self
self = None
if evt.sender == evt.client.mxid:
return
elif msgtypes and evt.content.msgtype not in msgtypes:
return
match = regex.match(field(evt))
if match:
if self:
await func(self, evt, *list(match.groups()))
data = field(evt)
if multiple:
val = [(data[match.pos:match.endpos], *match.groups())
for match in regex.finditer(data)]
else:
await func(evt, *list(match.groups()))
match = regex.match(data)
if match:
val = (data[match.pos:match.endpos], *match.groups())
else:
val = None
if val:
if self:
await func(self, evt, val)
else:
await func(evt, val)
if combine:
orig_replacement = replacement
@event.on(event_type)
@functools.wraps(func)
async def replacement(self, evt: MaubotMessageEvent = None) -> None:
await asyncio.gather(combine(self, evt), orig_replacement(self, evt))
replacement.__mb_passive_orig__ = func
return replacement