diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index 48fafce..d18894f 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -30,6 +30,18 @@ class MaubotZipImportError(Exception): pass +class MaubotZipMetaError(MaubotZipImportError): + pass + + +class MaubotZipPreLoadError(MaubotZipImportError): + pass + + +class MaubotZipLoadError(MaubotZipImportError): + pass + + class ZippedPluginLoader(PluginLoader): path_cache: Dict[str, 'ZippedPluginLoader'] = {} log: logging.Logger = logging.getLogger("maubot.loader.zip") @@ -96,16 +108,16 @@ class ZippedPluginLoader(PluginLoader): file = ZipFile(source) data = file.read("maubot.ini") except FileNotFoundError as e: - raise MaubotZipImportError("Maubot plugin not found") from e + raise MaubotZipMetaError("Maubot plugin not found") from e except BadZipFile as e: - raise MaubotZipImportError("File is not a maubot plugin") from e + raise MaubotZipMetaError("File is not a maubot plugin") from e except KeyError as e: - raise MaubotZipImportError("File does not contain a maubot plugin definition") from e + raise MaubotZipMetaError("File does not contain a maubot plugin definition") from e config = configparser.ConfigParser() try: config.read_string(data.decode("utf-8")) except (configparser.Error, KeyError, IndexError, ValueError) as e: - raise MaubotZipImportError("Maubot plugin definition in file is invalid") from e + raise MaubotZipMetaError("Maubot plugin definition in file is invalid") from e return file, config @classmethod @@ -120,7 +132,7 @@ class ZippedPluginLoader(PluginLoader): if "/" in main_class: main_module, main_class = main_class.split("/")[:2] except (configparser.Error, KeyError, IndexError, ValueError) as e: - raise MaubotZipImportError("Maubot plugin definition in file is invalid") from e + raise MaubotZipMetaError("Maubot plugin definition in file is invalid") from e return meta_id, version, modules, main_class, main_module @classmethod @@ -133,7 +145,7 @@ class ZippedPluginLoader(PluginLoader): file, config = self._open_meta(self.path) meta = self._read_meta(config) if self.id and meta[0] != self.id: - raise MaubotZipImportError("Maubot plugin ID changed during reload") + raise MaubotZipMetaError("Maubot plugin ID changed during reload") self.id, self.version, self.modules, self.main_class, self.main_module = meta self._file = file @@ -145,22 +157,22 @@ class ZippedPluginLoader(PluginLoader): self._importer.reset_cache() return self._importer except ZipImportError as e: - raise MaubotZipImportError("File not found or not a maubot plugin") from e + raise MaubotZipMetaError("File not found or not a maubot plugin") from e def _run_preload_checks(self, importer: zipimporter) -> None: try: code = importer.get_code(self.main_module.replace(".", "/")) if self.main_class not in code.co_names: - raise MaubotZipImportError( + raise MaubotZipPreLoadError( f"Main class {self.main_class} not in {self.main_module}") except ZipImportError as e: - raise MaubotZipImportError( + raise MaubotZipPreLoadError( f"Main module {self.main_module} not found in file") from e for module in self.modules: try: importer.find_module(module) except ZipImportError as e: - raise MaubotZipImportError(f"Module {module} not found in file") from e + raise MaubotZipPreLoadError(f"Module {module} not found in file") from e async def load(self, reset_cache: bool = False) -> Type[PluginClass]: try: @@ -175,13 +187,22 @@ class ZippedPluginLoader(PluginLoader): importer = self._get_importer(reset_cache=reset_cache) self._run_preload_checks(importer) if reset_cache: - self.log.debug(f"Preloaded plugin {self.id} from {self.path}") + self.log.debug(f"Re-preloaded plugin {self.id} from {self.path}") for module in self.modules: - importer.load_module(module) - main_mod = sys.modules[self.main_module] - plugin = getattr(main_mod, self.main_class) + try: + importer.load_module(module) + except ZipImportError as e: + raise MaubotZipLoadError(f"Module {module} not found in file") + try: + main_mod = sys.modules[self.main_module] + except KeyError as e: + raise MaubotZipLoadError(f"Main module {self.main_module} of plugin not found") from e + try: + plugin = getattr(main_mod, self.main_class) + except AttributeError as e: + raise MaubotZipLoadError(f"Main class {self.main_class} of plugin not found") from e if not issubclass(plugin, Plugin): - raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin") + raise MaubotZipLoadError("Main class of plugin does not extend maubot.Plugin") self._loaded = plugin self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}") return plugin diff --git a/maubot/management/api/plugin.py b/maubot/management/api/plugin.py index 6e67c06..7345bb3 100644 --- a/maubot/management/api/plugin.py +++ b/maubot/management/api/plugin.py @@ -15,11 +15,12 @@ # along with this program. If not, see . from aiohttp import web from io import BytesIO +import traceback import os.path from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError -from .responses import (ErrPluginNotFound, ErrPluginInUse, ErrInputPluginInvalid, - ErrPluginReloadFailed, RespDeleted, RespOK) +from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error, + plugin_reload_error, RespDeleted, RespOK, ErrUnsupportedPluginLoader) from . import routes, config @@ -62,7 +63,55 @@ async def reload_plugin(request: web.Request) -> web.Response: plugin = PluginLoader.id_cache.get(plugin_id, None) if not plugin: return ErrPluginNotFound - return await reload(plugin) + + await plugin.stop_instances() + try: + await plugin.reload() + except MaubotZipImportError as e: + return plugin_reload_error(str(e), traceback.format_exc()) + await plugin.start_instances() + return RespOK + + +async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Response: + path = os.path.join(config["plugin_directories.upload"], f"{pid}-v{version}.mbp") + with open(path, "wb") as p: + p.write(content) + try: + ZippedPluginLoader.get(path) + except MaubotZipImportError as e: + ZippedPluginLoader.trash(path) + return plugin_import_error(str(e), traceback.format_exc()) + return RespOK + + +async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str + ) -> web.Response: + dirname = os.path.dirname(plugin.path) + filename = os.path.basename(plugin.path) + if plugin.version in filename: + filename = filename.replace(plugin.version, new_version) + else: + filename = filename.rstrip(".mbp") + new_version + ".mbp" + path = os.path.join(dirname, filename) + with open(path, "wb") as p: + p.write(content) + old_path = plugin.path + plugin.path = path + await plugin.stop_instances() + try: + await plugin.reload() + except MaubotZipImportError as e: + plugin.path = old_path + try: + await plugin.reload() + except MaubotZipImportError: + pass + await plugin.start_instances() + return plugin_import_error(str(e), traceback.format_exc()) + await plugin.start_instances() + ZippedPluginLoader.trash(plugin.path, reason="update") + return RespOK @routes.post("/plugins/upload") @@ -72,40 +121,11 @@ async def upload_plugin(request: web.Request) -> web.Response: try: pid, version = ZippedPluginLoader.verify_meta(file) except MaubotZipImportError as e: - return ErrInputPluginInvalid(e) + return plugin_import_error(str(e), traceback.format_exc()) plugin = PluginLoader.id_cache.get(pid, None) if not plugin: - path = os.path.join(config["plugin_directories.upload"], f"{pid}-v{version}.mbp") - with open(path, "wb") as p: - p.write(content) - try: - ZippedPluginLoader.get(path) - except MaubotZipImportError as e: - ZippedPluginLoader.trash(path) - # TODO log error? - return ErrInputPluginInvalid(e) + return await upload_new_plugin(content, pid, version) elif isinstance(plugin, ZippedPluginLoader): - dirname = os.path.dirname(plugin.path) - filename = os.path.basename(plugin.path) - if plugin.version in filename: - filename = filename.replace(plugin.version, version) - else: - filename = filename.rstrip(".mbp") + version + ".mbp" - path = os.path.join(dirname, filename) - with open(path, "wb") as p: - p.write(content) - ZippedPluginLoader.trash(plugin.path, reason="update") - plugin.path = path - return await reload(plugin) + return await upload_replacement_plugin(plugin, content, version) else: - return web.json_response({}) - - -async def reload(plugin: PluginLoader) -> web.Response: - await plugin.stop_instances() - try: - await plugin.reload() - except MaubotZipImportError as e: - return ErrPluginReloadFailed(e) - await plugin.start_instances() - return RespOK + return ErrUnsupportedPluginLoader diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index 993a354..d328cc2 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -36,20 +36,27 @@ ErrPluginInUse = web.json_response({ }, status=web.HTTPPreconditionFailed) -def ErrInputPluginInvalid(error) -> web.Response: +def plugin_import_error(error: str, stacktrace: str) -> web.Response: return web.json_response({ - "error": str(error), + "error": error, + "stacktrace": stacktrace, "errcode": "plugin_invalid", }, status=web.HTTPBadRequest) -def ErrPluginReloadFailed(error) -> web.Response: +def plugin_reload_error(error: str, stacktrace: str) -> web.Response: return web.json_response({ - "error": str(error), - "errcode": "plugin_invalid", + "error": error, + "stacktrace": stacktrace, + "errcode": "plugin_reload_fail", }, status=web.HTTPInternalServerError) +ErrUnsupportedPluginLoader = web.json_response({ + "error": "Existing plugin with same ID uses unsupported plugin loader", + "errcode": "unsupported_plugin_loader", +}, status=web.HTTPBadRequest) + ErrNotImplemented = web.json_response({ "error": "Not implemented", "errcode": "not_implemented",