Update type hints and allow InternalEventTypes for @event.on

This commit is contained in:
Tulir Asokan 2019-09-27 00:45:30 +03:00
parent ed6055744f
commit 7fcb7cbf0a
2 changed files with 6 additions and 5 deletions

View File

@ -16,19 +16,20 @@
from typing import Callable, Union, NewType from typing import Callable, Union, NewType
from mautrix.types import EventType from mautrix.types import EventType
from mautrix.client import EventHandler from mautrix.client import EventHandler, InternalEventType
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler]) EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]: def on(var: Union[EventType, InternalEventType, EventHandler]
) -> Union[EventHandlerDecorator, EventHandler]:
def decorator(func: EventHandler) -> EventHandler: def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True func.__mb_event_handler__ = True
if isinstance(var, EventType): if isinstance(var, (EventType, InternalEventType)):
func.__mb_event_type__ = var func.__mb_event_type__ = var
else: else:
func.__mb_event_type__ = EventType.ALL func.__mb_event_type__ = EventType.ALL
return func return func
return decorator if isinstance(var, EventType) else decorator(var) return decorator if isinstance(var, (EventType, InternalEventType)) else decorator(var)

View File

@ -73,7 +73,7 @@ class MaubotMessageEvent(MessageEvent):
def mark_read(self) -> Awaitable[None]: def mark_read(self) -> Awaitable[None]:
return self.client.send_receipt(self.room_id, self.event_id, "m.read") return self.client.send_receipt(self.room_id, self.event_id, "m.read")
def react(self, key: str) -> Awaitable[None]: def react(self, key: str) -> Awaitable[EventID]:
return self.client.react(self.room_id, self.event_id, key) return self.client.react(self.room_id, self.event_id, key)
def edit(self, content: Union[str, MessageEventContent], def edit(self, content: Union[str, MessageEventContent],