diff --git a/maubot/__main__.py b/maubot/__main__.py index f91b30d..aaae853 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -20,13 +20,13 @@ import argparse import asyncio import copy import sys -import os from .config import Config from .db import Base, init as init_db from .server import MaubotServer from .client import Client, init as init_client -from .loader import ZippedPluginLoader, MaubotZipImportError +from .loader import ZippedPluginLoader +from .plugin import PluginInstance from .__meta__ import __version__ parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", @@ -57,27 +57,22 @@ loop = asyncio.get_event_loop() init_db(db_session) init_client(loop) server = MaubotServer(config, loop) +ZippedPluginLoader.load_all(*config["plugin_directories"]) +plugins = PluginInstance.all() -loader_log = logging.getLogger("maubot.loader.zip") -loader_log.debug("Preloading plugins...") -for directory in config["plugin_directories"]: - for file in os.listdir(directory): - if not file.endswith(".mbp"): - continue - path = os.path.join(directory, file) - try: - loader = ZippedPluginLoader.get(path) - loader_log.debug(f"Preloaded plugin {loader.id} from {loader.path}.") - except MaubotZipImportError: - loader_log.exception(f"Failed to load plugin at {path}.") +for plugin in plugins: + plugin.load() try: - loop.run_until_complete(server.start()) + loop.run_until_complete(asyncio.gather( + server.start(), + *[plugin.start() for plugin in plugins])) log.debug("Startup actions complete, running forever.") loop.run_forever() except KeyboardInterrupt: log.debug("Keyboard interrupt received, stopping...") for client in Client.cache.values(): client.stop() + db_session.commit() loop.run_until_complete(server.stop()) sys.exit(0) diff --git a/maubot/client.py b/maubot/client.py index f4ed9c3..80b947a 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -37,15 +37,21 @@ class Client: def __init__(self, db_instance: DBClient) -> None: self.db_instance = db_instance self.cache[self.id] = self - self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance, - mxid=self.id, base_url=self.homeserver, + self.log = log.getChild(self.id) + self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, token=self.access_token, client_session=self.http_client, - log=log.getChild(self.id)) + log=self.log, loop=self.loop, store=self.db_instance) if self.autojoin: self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) def start(self) -> None: - asyncio.ensure_future(self.client.start(), loop=self.loop) + asyncio.ensure_future(self._start(), loop=self.loop) + + async def _start(self) -> None: + try: + await self.client.start() + except Exception: + self.log.exception("Fail") def stop(self) -> None: self.client.stop() @@ -64,6 +70,10 @@ class Client: def all(cls) -> List['Client']: return [cls.get(user.id, user) for user in DBClient.query.all()] + async def _handle_invite(self, evt: StateEvent) -> None: + if evt.state_key == self.id and evt.content.membership == Membership.INVITE: + await self.client.join_room_by_id(evt.room_id) + # region Properties @property @@ -72,7 +82,7 @@ class Client: @property def homeserver(self) -> str: - return self.db_instance.id + return self.db_instance.homeserver @property def access_token(self) -> str: @@ -139,12 +149,9 @@ class Client: # endregion - async def _handle_invite(self, evt: StateEvent) -> None: - if evt.state_key == self.id and evt.content.membership == Membership.INVITE: - await self.client.join_room_by_id(evt.room_id) - def init(loop: asyncio.AbstractEventLoop) -> None: + Client.http_client = ClientSession(loop=loop) Client.loop = loop for client in Client.all(): client.start() diff --git a/maubot/lib/zipimport.py b/maubot/lib/zipimport.py index 0ef16cf..f9a0ca7 100644 --- a/maubot/lib/zipimport.py +++ b/maubot/lib/zipimport.py @@ -118,6 +118,12 @@ class zipimporter: self._files = _read_directory(self.archive) _zip_directory_cache[self.archive] = self._files + def remove_cache(self): + try: + del _zip_directory_cache[self.archive] + except KeyError: + pass + # Check whether we can satisfy the import of the module named by # 'fullname', or whether it could be a portion of a namespace # package. Return self if we can load it, a string containing the diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index 92d01f6..20b593a 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -13,11 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import TypeVar, Type, Dict +from typing import TypeVar, Type, Dict, Set, TYPE_CHECKING from abc import ABC, abstractmethod from ..plugin_base import Plugin +if TYPE_CHECKING: + from ..plugin import PluginInstance + PluginClass = TypeVar("PluginClass", bound=Plugin) @@ -28,9 +31,17 @@ class IDConflictError(Exception): class PluginLoader(ABC): id_cache: Dict[str, 'PluginLoader'] = {} + references: Set['PluginInstance'] id: str version: str + def __init__(self): + self.references = set() + + @classmethod + def find(cls, plugin_id: str) -> 'PluginLoader': + return cls.id_cache[plugin_id] + @property @abstractmethod def source(self) -> str: diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index f6cf645..57f341d 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -15,8 +15,10 @@ # along with this program. If not, see . from typing import Dict, List, Type from zipfile import ZipFile, BadZipFile -import sys import configparser +import logging +import sys +import os from ..lib.zipimport import zipimporter, ZipImportError from ..plugin_base import Plugin @@ -29,6 +31,7 @@ class MaubotZipImportError(Exception): class ZippedPluginLoader(PluginLoader): path_cache: Dict[str, 'ZippedPluginLoader'] = {} + log = logging.getLogger("maubot.loader.zip") path: str id: str @@ -40,9 +43,11 @@ class ZippedPluginLoader(PluginLoader): _importer: zipimporter def __init__(self, path: str) -> None: + super().__init__() self.path = path self.id = None self._loaded = None + self._importer = None self._load_meta() self._run_preload_checks(self._get_importer()) try: @@ -52,6 +57,7 @@ class ZippedPluginLoader(PluginLoader): pass self.path_cache[self.path] = self self.id_cache[self.id] = self + self.log.debug(f"Preloaded plugin {self.id} from {self.path}") @classmethod def get(cls, path: str) -> 'ZippedPluginLoader': @@ -68,7 +74,7 @@ class ZippedPluginLoader(PluginLoader): return ("") + f"loaded={self._loaded is not None}>") def _load_meta(self) -> None: try: @@ -100,10 +106,11 @@ class ZippedPluginLoader(PluginLoader): def _get_importer(self, reset_cache: bool = False) -> zipimporter: try: - importer = zipimporter(self.path) + if not self._importer: + self._importer = zipimporter(self.path) if reset_cache: - importer.reset_cache() - return importer + self._importer.reset_cache() + return self._importer except ZipImportError as e: raise MaubotZipImportError("File not found or not a maubot plugin") from e @@ -127,6 +134,8 @@ class ZippedPluginLoader(PluginLoader): return self._loaded importer = self._get_importer(reset_cache=reset_cache) self._run_preload_checks(importer) + if reset_cache: + self.log.debug(f"Preloaded plugin {self.id} from {self.path}") for module in self.modules: importer.load_module(module) main_mod = sys.modules[self.main_module] @@ -134,6 +143,7 @@ class ZippedPluginLoader(PluginLoader): if not issubclass(plugin, Plugin): raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin") self._loaded = plugin + self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}") return plugin def reload(self) -> Type[PluginClass]: @@ -144,6 +154,8 @@ class ZippedPluginLoader(PluginLoader): for name, mod in list(sys.modules.items()): if getattr(mod, "__file__", "").startswith(self.path): del sys.modules[name] + self._loaded = None + self.log.debug(f"Unloaded plugin {self.id} at {self.path}") def destroy(self) -> None: self.unload() @@ -155,3 +167,24 @@ class ZippedPluginLoader(PluginLoader): del self.id_cache[self.id] except KeyError: pass + self.id = None + self.path = None + self.version = None + self.modules = None + if self._importer: + self._importer.remove_cache() + self._importer = None + self._loaded = None + + @classmethod + def load_all(cls, *args: str) -> None: + cls.log.debug("Preloading plugins...") + for directory in args: + for file in os.listdir(directory): + if not file.endswith(".mbp"): + continue + path = os.path.join(directory, file) + try: + ZippedPluginLoader.get(path) + except (MaubotZipImportError, IDConflictError): + cls.log.exception(f"Failed to load plugin at {path}") diff --git a/maubot/plugin.py b/maubot/plugin.py index 36823a0..850a45e 100644 --- a/maubot/plugin.py +++ b/maubot/plugin.py @@ -13,12 +13,15 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List +from typing import Dict, List, Optional import logging from mautrix.types import UserID from .db import DBPlugin +from .client import Client +from .loader import PluginLoader +from .plugin_base import Plugin log = logging.getLogger("maubot.plugin") @@ -27,10 +30,56 @@ class PluginInstance: cache: Dict[str, 'PluginInstance'] = {} plugin_directories: List[str] = [] + log: logging.Logger + loader: PluginLoader + client: Client + plugin: Plugin + def __init__(self, db_instance: DBPlugin): self.db_instance = db_instance + self.log = logging.getLogger(f"maubot.plugin.{self.id}") self.cache[self.id] = self + def load(self) -> None: + try: + self.loader = PluginLoader.find(self.type) + except KeyError: + self.log.error(f"Failed to find loader for type {self.type}") + self.db_instance.enabled = False + return + self.client = Client.get(self.primary_user) + if not self.client: + self.log.error(f"Failed to get client for user {self.primary_user}") + self.db_instance.enabled = False + + async def start(self) -> None: + self.log.debug(f"Starting...") + cls = self.loader.load() + self.plugin = cls(self.client.client, self.id, self.log) + self.loader.references |= {self} + await self.plugin.start() + + async def stop(self) -> None: + self.log.debug("Stopping...") + self.loader.references -= {self} + await self.plugin.stop() + self.plugin = None + + @classmethod + def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None + ) -> Optional['PluginInstance']: + try: + return cls.cache[instance_id] + except KeyError: + db_instance = db_instance or DBPlugin.query.get(instance_id) + if not db_instance: + return None + return PluginInstance(db_instance) + + @classmethod + def all(cls) -> List['PluginInstance']: + return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()] + # region Properties @property diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 69dedf5..58f3435 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import TYPE_CHECKING +from logging import Logger from abc import ABC if TYPE_CHECKING: @@ -22,9 +23,14 @@ if TYPE_CHECKING: class Plugin(ABC): - def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None: + client: 'MaubotMatrixClient' + id: str + log: Logger + + def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None: self.client = client self.id = plugin_instance_id + self.log = log def set_command_spec(self, spec: 'CommandSpec') -> None: self.client.set_command_spec(self.id, spec)