diff --git a/maubot/management/api/__init__.py b/maubot/management/api/__init__.py index d3e2611..e66b527 100644 --- a/maubot/management/api/__init__.py +++ b/maubot/management/api/__init__.py @@ -16,34 +16,15 @@ from aiohttp import web from asyncio import AbstractEventLoop -from mautrix.types import UserID -from mautrix.util.signed_token import sign_token, verify_token - from ...config import Config - -routes = web.RouteTableDef() -config: Config = None - - -def is_valid_token(token: str) -> bool: - data = verify_token(config["server.unshared_secret"], token) - if not data: - return False - return config.is_admin(data.get("user_id", None)) - - -def create_token(user: UserID) -> str: - return sign_token(config["server.unshared_secret"], { - "user_id": user, - }) +from .base import routes, set_config +from .middleware import auth, error +from .auth import web as _ +from .plugin import web as _ def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: - global config - config = cfg - from .middleware import auth, error - from .auth import web as _ - from .plugin import web as _ + set_config(cfg) app = web.Application(loop=loop, middlewares=[auth, error]) app.add_routes(routes) return app diff --git a/maubot/management/api/auth.py b/maubot/management/api/auth.py index d08ca1c..1b0bcf3 100644 --- a/maubot/management/api/auth.py +++ b/maubot/management/api/auth.py @@ -16,10 +16,26 @@ from aiohttp import web import json -from . import routes, config, create_token +from mautrix.types import UserID +from mautrix.util.signed_token import sign_token, verify_token + +from .base import routes, get_config from .responses import ErrBadAuth, ErrBodyNotJSON +def is_valid_token(token: str) -> bool: + data = verify_token(get_config()["server.unshared_secret"], token) + if not data: + return False + return get_config().is_admin(data.get("user_id", None)) + + +def create_token(user: UserID) -> str: + return sign_token(get_config()["server.unshared_secret"], { + "user_id": user, + }) + + @routes.post("/login") async def login(request: web.Request) -> web.Response: try: diff --git a/maubot/management/api/base.py b/maubot/management/api/base.py new file mode 100644 index 0000000..d9c2077 --- /dev/null +++ b/maubot/management/api/base.py @@ -0,0 +1,30 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2018 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from aiohttp import web + +from ...config import Config + +routes: web.RouteTableDef = web.RouteTableDef() +_config: Config = None + + +def set_config(config: Config) -> None: + global _config + _config = config + + +def get_config() -> Config: + return _config diff --git a/maubot/management/api/middleware.py b/maubot/management/api/middleware.py index 61e7097..fa5b93a 100644 --- a/maubot/management/api/middleware.py +++ b/maubot/management/api/middleware.py @@ -17,7 +17,7 @@ from typing import Callable, Awaitable from aiohttp import web from .responses import ErrNoToken, ErrInvalidToken, ErrPathNotFound, ErrMethodNotAllowed -from . import is_valid_token +from .auth import is_valid_token Handler = Callable[[web.Request], Awaitable[web.Response]] diff --git a/maubot/management/api/plugin.py b/maubot/management/api/plugin.py index 644c158..4bdcbbd 100644 --- a/maubot/management/api/plugin.py +++ b/maubot/management/api/plugin.py @@ -21,7 +21,7 @@ import os.path from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error, plugin_reload_error, RespDeleted, RespOK, ErrUnsupportedPluginLoader) -from . import routes, config +from .base import routes, get_config def _plugin_to_dict(plugin: PluginLoader) -> dict: @@ -74,7 +74,7 @@ async def reload_plugin(request: web.Request) -> web.Response: async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Response: - path = os.path.join(config["plugin_directories.upload"], f"{pid}-v{version}.mbp") + path = os.path.join(get_config()["plugin_directories.upload"], f"{pid}-v{version}.mbp") with open(path, "wb") as p: p.write(content) try: