From eef052b1e90f2d2b3950be386d53ef1e25e3e1c0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Oct 2018 16:41:02 +0300 Subject: [PATCH] More changes --- example-config.yaml | 5 +-- maubot/__main__.py | 23 ++++++++++++- maubot/client.py | 78 ++++++++++++++----------------------------- maubot/db.py | 8 ++++- maubot/loader/abc.py | 5 ++- maubot/loader/zip.py | 1 - maubot/matrix.py | 77 ++++++++++++++++++++++++++++++++++++++++++ maubot/plugin_base.py | 5 +-- maubot/server.py | 54 ++++++++++++++++++++++++++++++ 9 files changed, 195 insertions(+), 61 deletions(-) create mode 100644 maubot/matrix.py create mode 100644 maubot/server.py diff --git a/example-config.yaml b/example-config.yaml index a951ae9..3920e3e 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -10,8 +10,9 @@ plugin_directories: - ./plugins server: - # The IP:port to listen to. - listen: 0.0.0.0:29316 + # The IP and port to listen to. + hostname: 0.0.0.0 + port: 29316 # The base management API path. base_path: /_matrix/maubot # The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1. diff --git a/maubot/__main__.py b/maubot/__main__.py index 2b97589..76dc72a 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -17,9 +17,14 @@ from sqlalchemy import orm import sqlalchemy as sql import logging.config import argparse +import asyncio import copy +import sys from .config import Config +from .db import Base, init as init_db +from .server import MaubotServer +from .client import Client, init as init_client from .__meta__ import __version__ parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", @@ -36,7 +41,23 @@ logging.config.dictConfig(copy.deepcopy(config["logging"])) log = logging.getLogger("maubot") log.debug(f"Initializing maubot {__version__}") -db_engine = sql.create_engine(config["database"]) +db_engine: sql.engine.Engine = sql.create_engine(config["database"]) db_factory = orm.sessionmaker(bind=db_engine) db_session = orm.scoping.scoped_session(db_factory) Base.metadata.bind=db_engine + +loop = asyncio.get_event_loop() + +init_db(db_session) +init_client(loop) +server = MaubotServer(config, loop) + +try: + loop.run_until_complete(server.start()) + loop.run_forever() +except KeyboardInterrupt: + log.debug("Keyboard interrupt received, stopping...") + for client in Client.cache.values(): + client.stop() + loop.run_until_complete(server.stop()) + sys.exit(0) diff --git a/maubot/client.py b/maubot/client.py index 30cccff..f4ed9c3 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -13,62 +13,21 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, Optional, Union, Callable +from typing import Dict, List, Optional from aiohttp import ClientSession import asyncio import logging -from mautrix import Client as MatrixClient -from mautrix.client import EventHandler -from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, - EventType, MessageEvent) +from mautrix.types import UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, EventType -from .command_spec import ParsedCommand from .db import DBClient +from .matrix import MaubotMatrixClient log = logging.getLogger("maubot.client") -class MaubotMatrixClient(MatrixClient): - def __init__(self, maubot_client: 'Client', *args, **kwargs): - super().__init__(*args, **kwargs) - self._maubot_client = maubot_client - self.command_handlers: Dict[str, List[EventHandler]] = {} - self.commands: List[ParsedCommand] = [] - - self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE) - - async def _command_event_handler(self, evt: MessageEvent) -> None: - for command in self.commands: - if command.match(evt): - await self._trigger_command(command, evt) - return - - async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None: - for handler in self.command_handlers.get(command.name, []): - await handler(evt) - - def on(self, var: Union[EventHandler, EventType, str] - ) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]: - if isinstance(var, str): - def decorator(func: EventHandler) -> EventHandler: - self.add_command_handler(var, func) - return func - - return decorator - return super().on(var) - - def add_command_handler(self, command: str, handler: EventHandler) -> None: - self.command_handlers.setdefault(command, []).append(handler) - - def remove_command_handler(self, command: str, handler: EventHandler) -> None: - try: - self.command_handlers[command].remove(handler) - except (KeyError, ValueError): - pass - - class Client: + loop: asyncio.AbstractEventLoop cache: Dict[UserID, 'Client'] = {} http_client: ClientSession = None @@ -78,26 +37,33 @@ class Client: def __init__(self, db_instance: DBClient) -> None: self.db_instance = db_instance self.cache[self.id] = self - self.client = MaubotMatrixClient(maubot_client=self, - store=self.db_instance, - mxid=self.id, - base_url=self.homeserver, - token=self.access_token, - client_session=self.http_client, + self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance, + mxid=self.id, base_url=self.homeserver, + token=self.access_token, client_session=self.http_client, log=log.getChild(self.id)) if self.autojoin: self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) + def start(self) -> None: + asyncio.ensure_future(self.client.start(), loop=self.loop) + + def stop(self) -> None: + self.client.stop() + @classmethod - def get(cls, user_id: UserID) -> Optional['Client']: + def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']: try: return cls.cache[user_id] except KeyError: - db_instance = DBClient.query.get(user_id) + db_instance = db_instance or DBClient.query.get(user_id) if not db_instance: return None return Client(db_instance) + @classmethod + def all(cls) -> List['Client']: + return [cls.get(user.id, user) for user in DBClient.query.all()] + # region Properties @property @@ -176,3 +142,9 @@ class Client: async def _handle_invite(self, evt: StateEvent) -> None: if evt.state_key == self.id and evt.content.membership == Membership.INVITE: await self.client.join_room_by_id(evt.room_id) + + +def init(loop: asyncio.AbstractEventLoop) -> None: + Client.loop = loop + for client in Client.all(): + client.start() diff --git a/maubot/db.py b/maubot/db.py index 9c4ccc1..fd4e4cc 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from typing import Type from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator) -from sqlalchemy.orm import Query +from sqlalchemy.orm import Query, scoped_session from sqlalchemy.ext.declarative import declarative_base import json @@ -89,3 +89,9 @@ class DBCommandSpec(Base): ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True) spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False) + + +def init(session: scoped_session) -> None: + DBPlugin.query = session.query_property() + DBClient.query = session.query_property() + DBCommandSpec.query = session.query_property() diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index e7b323d..71c6ce3 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -13,14 +13,17 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import TypeVar, Type +from typing import TypeVar, Type, Dict from abc import ABC, abstractmethod + from ..plugin_base import Plugin PluginClass = TypeVar("PluginClass", bound=Plugin) class PluginLoader(ABC): + id_cache: Dict[str, 'PluginLoader'] = {} + id: str version: str diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index bb44979..a69ccec 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -29,7 +29,6 @@ class MaubotZipImportError(Exception): class ZippedPluginLoader(PluginLoader): path_cache: Dict[str, 'ZippedPluginLoader'] = {} - id_cache: Dict[str, 'ZippedPluginLoader'] = {} path: str id: str diff --git a/maubot/matrix.py b/maubot/matrix.py new file mode 100644 index 0000000..11b6eba --- /dev/null +++ b/maubot/matrix.py @@ -0,0 +1,77 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2018 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Dict, List, Union, Callable + +from mautrix import Client as MatrixClient +from mautrix.client import EventHandler +from mautrix.types import EventType, MessageEvent + +from .command_spec import ParsedCommand, CommandSpec + + +class MaubotMatrixClient(MatrixClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.command_handlers: Dict[str, List[EventHandler]] = {} + self.commands: List[ParsedCommand] = [] + self.command_specs: Dict[str, CommandSpec] = {} + + self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE) + + def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None: + self.command_specs[plugin_id] = spec + self._reparse_command_specs() + + def _reparse_command_specs(self) -> None: + self.commands = [parsed_command + for spec in self.command_specs.values() + for parsed_command in spec.parse()] + + def remove_command_spec(self, plugin_id: str) -> None: + try: + del self.command_specs[plugin_id] + self._reparse_command_specs() + except KeyError: + pass + + async def _command_event_handler(self, evt: MessageEvent) -> None: + for command in self.commands: + if command.match(evt): + await self._trigger_command(command, evt) + return + + async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None: + for handler in self.command_handlers.get(command.name, []): + await handler(evt) + + def on(self, var: Union[EventHandler, EventType, str] + ) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]: + if isinstance(var, str): + def decorator(func: EventHandler) -> EventHandler: + self.add_command_handler(var, func) + return func + + return decorator + return super().on(var) + + def add_command_handler(self, command: str, handler: EventHandler) -> None: + self.command_handlers.setdefault(command, []).append(handler) + + def remove_command_handler(self, command: str, handler: EventHandler) -> None: + try: + self.command_handlers[command].remove(handler) + except (KeyError, ValueError): + pass diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 0f22c00..69dedf5 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -22,11 +22,12 @@ if TYPE_CHECKING: class Plugin(ABC): - def __init__(self, client: 'MaubotMatrixClient') -> None: + def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None: self.client = client + self.id = plugin_instance_id def set_command_spec(self, spec: 'CommandSpec') -> None: - pass + self.client.set_command_spec(self.id, spec) async def start(self) -> None: pass diff --git a/maubot/server.py b/maubot/server.py new file mode 100644 index 0000000..55a4b4a --- /dev/null +++ b/maubot/server.py @@ -0,0 +1,54 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2018 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from aiohttp import web +import asyncio + +from mautrix.api import PathBuilder + +from .config import Config +from .__meta__ import __version__ + + +class MaubotServer: + def __init__(self, config: Config, loop: asyncio.AbstractEventLoop): + self.loop = loop or asyncio.get_event_loop() + self.app = web.Application(loop=self.loop) + self.config = config + + path = PathBuilder(config["server.base_path"]) + self.app.router.add_get(path.version, self.version) + + as_path = PathBuilder(config["server.appservice_base_path"]) + self.app.router.add_put(as_path.transactions, self.handle_transaction) + + self.runner = web.AppRunner(self.app) + + async def start(self) -> None: + await self.runner.setup() + site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"]) + await site.start() + + async def stop(self) -> None: + await self.runner.cleanup() + + @staticmethod + async def version(_: web.Request) -> web.Response: + return web.json_response({ + "version": __version__ + }) + + async def handle_transaction(self, request: web.Request) -> web.Response: + return web.Response(status=501)