diff --git a/maubot/__main__.py b/maubot/__main__.py
index f91b30d..aaae853 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -20,13 +20,13 @@ import argparse
import asyncio
import copy
import sys
-import os
from .config import Config
from .db import Base, init as init_db
from .server import MaubotServer
from .client import Client, init as init_client
-from .loader import ZippedPluginLoader, MaubotZipImportError
+from .loader import ZippedPluginLoader
+from .plugin import PluginInstance
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
@@ -57,27 +57,22 @@ loop = asyncio.get_event_loop()
init_db(db_session)
init_client(loop)
server = MaubotServer(config, loop)
+ZippedPluginLoader.load_all(*config["plugin_directories"])
+plugins = PluginInstance.all()
-loader_log = logging.getLogger("maubot.loader.zip")
-loader_log.debug("Preloading plugins...")
-for directory in config["plugin_directories"]:
- for file in os.listdir(directory):
- if not file.endswith(".mbp"):
- continue
- path = os.path.join(directory, file)
- try:
- loader = ZippedPluginLoader.get(path)
- loader_log.debug(f"Preloaded plugin {loader.id} from {loader.path}.")
- except MaubotZipImportError:
- loader_log.exception(f"Failed to load plugin at {path}.")
+for plugin in plugins:
+ plugin.load()
try:
- loop.run_until_complete(server.start())
+ loop.run_until_complete(asyncio.gather(
+ server.start(),
+ *[plugin.start() for plugin in plugins]))
log.debug("Startup actions complete, running forever.")
loop.run_forever()
except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping...")
for client in Client.cache.values():
client.stop()
+ db_session.commit()
loop.run_until_complete(server.stop())
sys.exit(0)
diff --git a/maubot/client.py b/maubot/client.py
index f4ed9c3..80b947a 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -37,15 +37,21 @@ class Client:
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
- self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
- mxid=self.id, base_url=self.homeserver,
+ self.log = log.getChild(self.id)
+ self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
- log=log.getChild(self.id))
+ log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None:
- asyncio.ensure_future(self.client.start(), loop=self.loop)
+ asyncio.ensure_future(self._start(), loop=self.loop)
+
+ async def _start(self) -> None:
+ try:
+ await self.client.start()
+ except Exception:
+ self.log.exception("Fail")
def stop(self) -> None:
self.client.stop()
@@ -64,6 +70,10 @@ class Client:
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
+ async def _handle_invite(self, evt: StateEvent) -> None:
+ if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
+ await self.client.join_room_by_id(evt.room_id)
+
# region Properties
@property
@@ -72,7 +82,7 @@ class Client:
@property
def homeserver(self) -> str:
- return self.db_instance.id
+ return self.db_instance.homeserver
@property
def access_token(self) -> str:
@@ -139,12 +149,9 @@ class Client:
# endregion
- async def _handle_invite(self, evt: StateEvent) -> None:
- if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
- await self.client.join_room_by_id(evt.room_id)
-
def init(loop: asyncio.AbstractEventLoop) -> None:
+ Client.http_client = ClientSession(loop=loop)
Client.loop = loop
for client in Client.all():
client.start()
diff --git a/maubot/lib/zipimport.py b/maubot/lib/zipimport.py
index 0ef16cf..f9a0ca7 100644
--- a/maubot/lib/zipimport.py
+++ b/maubot/lib/zipimport.py
@@ -118,6 +118,12 @@ class zipimporter:
self._files = _read_directory(self.archive)
_zip_directory_cache[self.archive] = self._files
+ def remove_cache(self):
+ try:
+ del _zip_directory_cache[self.archive]
+ except KeyError:
+ pass
+
# Check whether we can satisfy the import of the module named by
# 'fullname', or whether it could be a portion of a namespace
# package. Return self if we can load it, a string containing the
diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py
index 92d01f6..20b593a 100644
--- a/maubot/loader/abc.py
+++ b/maubot/loader/abc.py
@@ -13,11 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import TypeVar, Type, Dict
+from typing import TypeVar, Type, Dict, Set, TYPE_CHECKING
from abc import ABC, abstractmethod
from ..plugin_base import Plugin
+if TYPE_CHECKING:
+ from ..plugin import PluginInstance
+
PluginClass = TypeVar("PluginClass", bound=Plugin)
@@ -28,9 +31,17 @@ class IDConflictError(Exception):
class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {}
+ references: Set['PluginInstance']
id: str
version: str
+ def __init__(self):
+ self.references = set()
+
+ @classmethod
+ def find(cls, plugin_id: str) -> 'PluginLoader':
+ return cls.id_cache[plugin_id]
+
@property
@abstractmethod
def source(self) -> str:
diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py
index f6cf645..57f341d 100644
--- a/maubot/loader/zip.py
+++ b/maubot/loader/zip.py
@@ -15,8 +15,10 @@
# along with this program. If not, see .
from typing import Dict, List, Type
from zipfile import ZipFile, BadZipFile
-import sys
import configparser
+import logging
+import sys
+import os
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
@@ -29,6 +31,7 @@ class MaubotZipImportError(Exception):
class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
+ log = logging.getLogger("maubot.loader.zip")
path: str
id: str
@@ -40,9 +43,11 @@ class ZippedPluginLoader(PluginLoader):
_importer: zipimporter
def __init__(self, path: str) -> None:
+ super().__init__()
self.path = path
self.id = None
self._loaded = None
+ self._importer = None
self._load_meta()
self._run_preload_checks(self._get_importer())
try:
@@ -52,6 +57,7 @@ class ZippedPluginLoader(PluginLoader):
pass
self.path_cache[self.path] = self
self.id_cache[self.id] = self
+ self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
@classmethod
def get(cls, path: str) -> 'ZippedPluginLoader':
@@ -68,7 +74,7 @@ class ZippedPluginLoader(PluginLoader):
return ("")
+ f"loaded={self._loaded is not None}>")
def _load_meta(self) -> None:
try:
@@ -100,10 +106,11 @@ class ZippedPluginLoader(PluginLoader):
def _get_importer(self, reset_cache: bool = False) -> zipimporter:
try:
- importer = zipimporter(self.path)
+ if not self._importer:
+ self._importer = zipimporter(self.path)
if reset_cache:
- importer.reset_cache()
- return importer
+ self._importer.reset_cache()
+ return self._importer
except ZipImportError as e:
raise MaubotZipImportError("File not found or not a maubot plugin") from e
@@ -127,6 +134,8 @@ class ZippedPluginLoader(PluginLoader):
return self._loaded
importer = self._get_importer(reset_cache=reset_cache)
self._run_preload_checks(importer)
+ if reset_cache:
+ self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
for module in self.modules:
importer.load_module(module)
main_mod = sys.modules[self.main_module]
@@ -134,6 +143,7 @@ class ZippedPluginLoader(PluginLoader):
if not issubclass(plugin, Plugin):
raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
self._loaded = plugin
+ self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
return plugin
def reload(self) -> Type[PluginClass]:
@@ -144,6 +154,8 @@ class ZippedPluginLoader(PluginLoader):
for name, mod in list(sys.modules.items()):
if getattr(mod, "__file__", "").startswith(self.path):
del sys.modules[name]
+ self._loaded = None
+ self.log.debug(f"Unloaded plugin {self.id} at {self.path}")
def destroy(self) -> None:
self.unload()
@@ -155,3 +167,24 @@ class ZippedPluginLoader(PluginLoader):
del self.id_cache[self.id]
except KeyError:
pass
+ self.id = None
+ self.path = None
+ self.version = None
+ self.modules = None
+ if self._importer:
+ self._importer.remove_cache()
+ self._importer = None
+ self._loaded = None
+
+ @classmethod
+ def load_all(cls, *args: str) -> None:
+ cls.log.debug("Preloading plugins...")
+ for directory in args:
+ for file in os.listdir(directory):
+ if not file.endswith(".mbp"):
+ continue
+ path = os.path.join(directory, file)
+ try:
+ ZippedPluginLoader.get(path)
+ except (MaubotZipImportError, IDConflictError):
+ cls.log.exception(f"Failed to load plugin at {path}")
diff --git a/maubot/plugin.py b/maubot/plugin.py
index 36823a0..850a45e 100644
--- a/maubot/plugin.py
+++ b/maubot/plugin.py
@@ -13,12 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, List
+from typing import Dict, List, Optional
import logging
from mautrix.types import UserID
from .db import DBPlugin
+from .client import Client
+from .loader import PluginLoader
+from .plugin_base import Plugin
log = logging.getLogger("maubot.plugin")
@@ -27,10 +30,56 @@ class PluginInstance:
cache: Dict[str, 'PluginInstance'] = {}
plugin_directories: List[str] = []
+ log: logging.Logger
+ loader: PluginLoader
+ client: Client
+ plugin: Plugin
+
def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance
+ self.log = logging.getLogger(f"maubot.plugin.{self.id}")
self.cache[self.id] = self
+ def load(self) -> None:
+ try:
+ self.loader = PluginLoader.find(self.type)
+ except KeyError:
+ self.log.error(f"Failed to find loader for type {self.type}")
+ self.db_instance.enabled = False
+ return
+ self.client = Client.get(self.primary_user)
+ if not self.client:
+ self.log.error(f"Failed to get client for user {self.primary_user}")
+ self.db_instance.enabled = False
+
+ async def start(self) -> None:
+ self.log.debug(f"Starting...")
+ cls = self.loader.load()
+ self.plugin = cls(self.client.client, self.id, self.log)
+ self.loader.references |= {self}
+ await self.plugin.start()
+
+ async def stop(self) -> None:
+ self.log.debug("Stopping...")
+ self.loader.references -= {self}
+ await self.plugin.stop()
+ self.plugin = None
+
+ @classmethod
+ def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
+ ) -> Optional['PluginInstance']:
+ try:
+ return cls.cache[instance_id]
+ except KeyError:
+ db_instance = db_instance or DBPlugin.query.get(instance_id)
+ if not db_instance:
+ return None
+ return PluginInstance(db_instance)
+
+ @classmethod
+ def all(cls) -> List['PluginInstance']:
+ return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
+
# region Properties
@property
diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py
index 69dedf5..58f3435 100644
--- a/maubot/plugin_base.py
+++ b/maubot/plugin_base.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
from typing import TYPE_CHECKING
+from logging import Logger
from abc import ABC
if TYPE_CHECKING:
@@ -22,9 +23,14 @@ if TYPE_CHECKING:
class Plugin(ABC):
- def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
+ client: 'MaubotMatrixClient'
+ id: str
+ log: Logger
+
+ def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None:
self.client = client
self.id = plugin_instance_id
+ self.log = log
def set_command_spec(self, spec: 'CommandSpec') -> None:
self.client.set_command_spec(self.id, spec)