diff --git a/maubot/instance.py b/maubot/instance.py index af7d86a..114858d 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -35,7 +35,7 @@ from .loader import PluginLoader, ZippedPluginLoader from .plugin_base import Plugin if TYPE_CHECKING: - from .server import MaubotServer + from .server import MaubotServer, PluginWebApp log = logging.getLogger("maubot.instance") @@ -59,7 +59,7 @@ class PluginInstance: base_cfg: RecursiveDict[CommentedMap] inst_db: sql.engine.Engine inst_db_tables: Dict[str, sql.Table] - inst_webapp: web.Application + inst_webapp: 'PluginWebApp' inst_webapp_url: str started: bool diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index ce2d2f4..e07cae1 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -17,7 +17,6 @@ from typing import Type, Optional, TYPE_CHECKING from abc import ABC from logging import Logger from asyncio import AbstractEventLoop -from aiohttp.web import Application from sqlalchemy.engine.base import Engine from aiohttp import ClientSession @@ -25,6 +24,7 @@ from aiohttp import ClientSession if TYPE_CHECKING: from mautrix.util.config import BaseProxyConfig from .client import MaubotMatrixClient + from .server import PluginWebApp class Plugin(ABC): @@ -34,10 +34,12 @@ class Plugin(ABC): loop: AbstractEventLoop config: Optional['BaseProxyConfig'] database: Optional[Engine] + webapp: Optional['PluginWebApp'] + webapp_url: Optional[str] def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession, instance_id: str, log: Logger, config: Optional['BaseProxyConfig'], - database: Optional[Engine], webapp: Optional[Application], + database: Optional[Engine], webapp: Optional['PluginWebApp'], webapp_url: Optional[str]) -> None: self.client = client self.loop = loop diff --git a/maubot/server.py b/maubot/server.py index 1128fb0..91e113f 100644 --- a/maubot/server.py +++ b/maubot/server.py @@ -13,11 +13,12 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, Dict +from typing import Tuple, List, Dict, Callable, Awaitable +from functools import partial import logging import asyncio -from aiohttp import web +from aiohttp import web, hdrs, URL from aiohttp.abc import AbstractAccessLogger import pkg_resources @@ -34,6 +35,62 @@ class AccessLogger(AbstractAccessLogger): f'in {round(time, 4)}s"') +Handler = Callable[[web.Request], Awaitable[web.Response]] +Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]] + + +class PluginWebApp(web.UrlDispatcher): + def __init__(self): + super().__init__() + self._middleware: List[Middleware] = [] + + def add_middleware(self, middleware: Middleware) -> None: + self._middleware.append(middleware) + + def remove_middleware(self, middleware: Middleware) -> None: + self._middleware.remove(middleware) + + async def handle(self, request: web.Request) -> web.Response: + match_info = await self.resolve(request) + match_info.freeze() + resp = None + request._match_info = match_info + expect = request.headers.get(hdrs.EXPECT) + if expect: + resp = await match_info.expect_handler(request) + await request.writer.drain() + if resp is None: + handler = match_info.handler + for middleware in self._middleware: + handler = partial(middleware, handler=handler) + resp = await handler(request) + return resp + + +class PrefixResource(web.Resource): + def __init__(self, prefix, *, name=None): + assert not prefix or prefix.startswith('/'), prefix + assert prefix in ('', '/') or not prefix.endswith('/'), prefix + super().__init__(name=name) + self._prefix = URL.build(path=prefix).raw_path + + @property + def canonical(self): + return self._prefix + + def add_prefix(self, prefix): + assert prefix.startswith('/') + assert not prefix.endswith('/') + assert len(prefix) > 1 + self._prefix = prefix + self._prefix + + def _match(self, path: str) -> dict: + return {} if self.raw_match(path) else None + + def raw_match(self, path: str) -> bool: + return path and path.startswith(self._prefix) + + class MaubotServer: log: logging.Logger = logging.getLogger("maubot.server") @@ -45,38 +102,38 @@ class MaubotServer: as_path = PathBuilder(config["server.appservice_base_path"]) self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) - self.plugin_apps: Dict[str, web.Application] = {} - self.app.router.add_view(config["server.plugin_base_path"], self.handle_plugin_path) + self.plugin_routes: Dict[str, PluginWebApp] = {} + resource = PrefixResource(config["server.plugin_base_path"]) + resource.add_route(hdrs.METH_ANY, self.handle_plugin_path) + self.app.router.register_resource(resource) self.setup_management_ui() self.runner = web.AppRunner(self.app, access_log_class=AccessLogger) async def handle_plugin_path(self, request: web.Request) -> web.Response: - for path, app in self.plugin_apps.items(): + for path, app in self.plugin_routes.items(): if request.path.startswith(path): - # TODO there's probably a correct way to do these - request._rel_url.path = request._rel_url.path[len(path):] - return await app._handle(request) + request = request.clone(rel_url=request.path[len(path):]) + return await app.handle(request) return web.Response(status=404) - def get_instance_subapp(self, instance_id: str) -> Tuple[web.Application, str]: - subpath = self.config["server.plugin_base_path"].format(id=instance_id) + def get_instance_subapp(self, instance_id: str) -> Tuple[PluginWebApp, str]: + subpath = self.config["server.plugin_base_path"] + instance_id url = self.config["server.public_url"] + subpath try: - return self.plugin_apps[subpath], url + return self.plugin_routes[subpath], url except KeyError: - app = web.Application(loop=self.loop) - self.plugin_apps[subpath] = app + app = PluginWebApp() + self.plugin_routes[subpath] = app return app, url def remove_instance_webapp(self, instance_id: str) -> None: try: - subapp: web.Application = self.plugin_apps.pop(instance_id) + subpath = self.config["server.plugin_base_path"] + instance_id + self.plugin_routes.pop(subpath) except KeyError: return - subapp.shutdown() - subapp.cleanup() def setup_management_ui(self) -> None: ui_base = self.config["server.ui_base_path"]