Refactor things and implement instance API

This commit is contained in:
Tulir Asokan 2018-11-01 18:11:54 +02:00
parent cbeff0c0cb
commit bc87b2a02b
14 changed files with 249 additions and 100 deletions

View File

@ -13,8 +13,6 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import orm
import sqlalchemy as sql
import logging.config
import argparse
import asyncio
@ -23,11 +21,11 @@ import copy
import sys
from .config import Config
from .db import Base, init as init_db
from .db import init as init_db
from .server import MaubotServer
from .client import Client, init as init_client
from .loader import ZippedPluginLoader
from .instance import PluginInstance, init as init_plugin_instance_class
from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class
from .management.api import init as init_management
from .__meta__ import __version__
@ -46,57 +44,48 @@ config.update()
logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("maubot.init")
log.debug(f"Initializing maubot {__version__}")
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoping.scoped_session(db_factory)
Base.metadata.bind = db_engine
Base.metadata.create_all()
log.info(f"Initializing maubot {__version__}")
loop = asyncio.get_event_loop()
init_db(db_session)
clients = init_client(loop)
init_plugin_instance_class(db_session, config, loop)
init_zip_loader(config)
db_session = init_db(config)
clients = init_client_class(db_session, loop)
plugins = init_plugin_instance_class(db_session, config, loop)
management_api = init_management(config, loop)
server = MaubotServer(config, management_api, loop)
ZippedPluginLoader.trash_path = config["plugin_directories.trash"]
ZippedPluginLoader.directories = config["plugin_directories.load"]
ZippedPluginLoader.load_all()
plugins = PluginInstance.all()
for plugin in plugins:
plugin.load()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
stop = False
async def periodic_commit():
while not stop:
while True:
await asyncio.sleep(60)
db_session.commit()
periodic_commit_task: asyncio.Future = None
try:
log.debug("Starting server")
log.info("Starting server")
loop.run_until_complete(server.start())
log.debug("Starting clients and plugins")
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.debug("Startup actions complete, running forever")
loop.run_until_complete(periodic_commit())
log.info("Startup actions complete, running forever")
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever()
except KeyboardInterrupt:
log.debug("Interrupt received, stopping HTTP clients/servers and saving database")
stop = True
if periodic_commit_task is not None:
periodic_commit_task.cancel()
for client in Client.cache.values():
client.stop()
db_session.commit()
loop.run_until_complete(server.stop())
loop.close()
log.debug("Everything stopped, shutting down")
sys.exit(0)

View File

@ -14,10 +14,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Set, TYPE_CHECKING
from aiohttp import ClientSession
import asyncio
import logging
from sqlalchemy.orm import Session
from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
EventType, Filter, RoomFilter, RoomEventFilter)
@ -32,6 +34,7 @@ log = logging.getLogger("maubot.client")
class Client:
db: Session = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
@ -73,12 +76,12 @@ class Client:
user_id = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.enabled = False
self.db_instance.enabled = False
return
except MatrixRequestError:
if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client")
self.enabled = False
self.db_instance.enabled = False
else:
self.log.exception(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s")
@ -86,7 +89,7 @@ class Client:
return
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
self.enabled = False
self.db_instance.enabled = False
return
if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter(
@ -100,8 +103,7 @@ class Client:
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
if self.sync:
self.client.start(self.filter_id)
self.start_sync()
self.started = True
self.log.info("Client started, starting plugin instances...")
await self.start_plugins()
@ -110,12 +112,19 @@ class Client:
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None:
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started],
loop=self.loop)
def start_sync(self) -> None:
if self.sync:
self.client.start(self.filter_id)
def stop_sync(self) -> None:
self.client.stop()
def stop(self) -> None:
self.started = False
self.client.stop()
self.stop_sync()
def to_dict(self) -> dict:
return {
@ -233,7 +242,8 @@ class Client:
# endregion
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.db = db
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()

View File

@ -13,12 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import cast
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
from sqlalchemy.orm import Query, scoped_session
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from .config import Config
Base: declarative_base = declarative_base()
@ -54,6 +59,14 @@ class DBClient(Base):
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
def init(session: scoped_session) -> None:
DBPlugin.query = session.query_property()
DBClient.query = session.query_property()
def init(config: Config) -> Session:
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = sessionmaker(bind=db_engine)
db_session = scoped_session(db_factory)
Base.metadata.bind = db_engine
Base.metadata.create_all()
DBPlugin.query = db_session.query_property()
DBClient.query = db_session.query_property()
return cast(Session, db_session)

View File

@ -48,13 +48,14 @@ class PluginInstance:
client: Client
plugin: Plugin
config: BaseProxyConfig
running: bool
base_cfg: RecursiveDict[CommentedMap]
started: bool
def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
self.config = None
self.running = False
self.started = False
self.cache[self.id] = self
def to_dict(self) -> dict:
@ -62,7 +63,7 @@ class PluginInstance:
"id": self.id,
"type": self.type,
"enabled": self.enabled,
"running": self.running,
"started": self.started,
"primary_user": self.primary_user,
}
@ -71,19 +72,26 @@ class PluginInstance:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
self.enabled = False
self.db_instance.enabled = False
return
self.client = Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
self.enabled = False
self.db_instance.enabled = False
return
self.log.debug("Plugin instance dependencies loaded")
self.loader.references.add(self)
self.client.references.add(self)
def delete(self) -> None:
if self.loader is not None:
self.loader.references.remove(self)
if self.client is not None:
self.client.references.remove(self)
try:
del self.cache[self.id]
except KeyError:
pass
self.db.delete(self.db_instance)
# TODO delete plugin db
@ -96,7 +104,7 @@ class PluginInstance:
self.db_instance.config = buf.getvalue()
async def start(self) -> None:
if self.running:
if self.started:
self.log.warning("Ignoring start() call to already started plugin")
return
elif not self.enabled:
@ -107,28 +115,28 @@ class PluginInstance:
if config_class:
try:
base = await self.loader.read_file("base-config.yaml")
base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
self.base_cfg = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
except (FileNotFoundError, KeyError):
base_file = None
self.config = config_class(self.load_config, lambda: base_file, self.save_config)
self.base_cfg = None
self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config)
self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id,
self.log, self.config, self.mb_config["plugin_directories.db"])
try:
await self.plugin.start()
except Exception:
self.log.exception("Failed to start instance")
self.enabled = False
self.db_instance.enabled = False
return
self.running = True
self.started = True
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
f"with user {self.client.id}")
async def stop(self) -> None:
if not self.running:
if not self.started:
self.log.warning("Ignoring stop() call to non-running plugin")
return
self.log.debug("Stopping plugin instance...")
self.running = False
self.started = False
try:
await self.plugin.stop()
except Exception:
@ -150,6 +158,37 @@ class PluginInstance:
def all(cls) -> List['PluginInstance']:
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
self.db_instance.id = new_id
def update_config(self, config: str) -> None:
if not config or self.db_instance.config == config:
return
self.db_instance.config = config
if self.started and self.plugin is not None:
self.plugin.on_external_config_update()
async def update_primary_user(self, primary_user: UserID) -> bool:
client = Client.get(primary_user)
if not client:
return False
await self.stop()
self.db_instance.primary_user = client.id
self.client.references.remove(self)
self.client = client
await self.start()
self.log.debug(f"Primary user switched to {self.client.id}")
return True
async def update_started(self, started: bool) -> None:
if started is not None and started != self.started:
await (self.start() if started else self.stop())
def update_enabled(self, enabled: bool) -> None:
if enabled is not None and enabled != self.enabled:
self.db_instance.enabled = enabled
# region Properties
@property
@ -168,22 +207,15 @@ class PluginInstance:
def enabled(self) -> bool:
return self.db_instance.enabled
@enabled.setter
def enabled(self, value: bool) -> None:
self.db_instance.enabled = value
@property
def primary_user(self) -> UserID:
return self.db_instance.primary_user
@primary_user.setter
def primary_user(self, value: UserID) -> None:
self.db_instance.primary_user = value
# endregion
def init(db: Session, config: Config, loop: AbstractEventLoop):
def init(db: Session, config: Config, loop: AbstractEventLoop) -> List[PluginInstance]:
PluginInstance.db = db
PluginInstance.mb_config = config
PluginInstance.loop = loop
return PluginInstance.all()

View File

@ -61,7 +61,7 @@ class PluginLoader(ABC):
async def stop_instances(self) -> None:
await asyncio.gather(*[instance.stop() for instance
in self.references if instance.running])
in self.references if instance.started])
async def start_instances(self) -> None:
await asyncio.gather(*[instance.start() for instance

View File

@ -23,6 +23,7 @@ import os
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
from ..config import Config
from .abc import PluginLoader, PluginClass, IDConflictError
@ -264,3 +265,9 @@ class ZippedPluginLoader(PluginLoader):
except IDConflictError:
cls.log.error(f"Duplicate plugin ID at {path}, trashing...")
cls.trash(path)
def init(config: Config) -> None:
ZippedPluginLoader.trash_path = config["plugin_directories.trash"]
ZippedPluginLoader.directories = config["plugin_directories.load"]
ZippedPluginLoader.load_all()

View File

@ -21,6 +21,8 @@ from .base import routes, set_config
from .middleware import auth, error
from .auth import web as _
from .plugin import web as _
from .instance import web as _
from .client import web as _
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:

View File

@ -13,27 +13,57 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from json import JSONDecodeError
from aiohttp import web
from mautrix.types import UserID
from ...client import Client
from .base import routes
from .responses import ErrNotImplemented
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
@routes.get("/clients")
def get_clients(request: web.Request) -> web.Response:
return ErrNotImplemented
async def get_clients(request: web.Request) -> web.Response:
return web.json_response([client.to_dict() for client in Client.cache.values()])
@routes.get("/client/{id}")
def get_client(request: web.Request) -> web.Response:
async def get_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return web.json_response(client.to_dict())
async def create_client(user_id: UserID, data: dict) -> web.Response:
return ErrNotImplemented
async def update_client(client: Client, data: dict) -> web.Response:
return ErrNotImplemented
@routes.put("/client/{id}")
def update_client(request: web.Request) -> web.Response:
return ErrNotImplemented
async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
if not client:
return await create_client(user_id, data)
else:
return await update_client(client, data)
@routes.delete("/client/{id}")
def delete_client(request: web.Request) -> web.Response:
async def delete_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return ErrNotImplemented

View File

@ -13,57 +13,88 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from json import JSONDecodeError
from mautrix.types import UserID
from aiohttp import web
from ...db import DBClient
from ...db import DBPlugin
from ...instance import PluginInstance
from ...loader import PluginLoader
from ...client import Client
from .base import routes
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
from .responses import (ErrInstanceNotFound, ErrBodyNotJSON, RespDeleted, ErrPrimaryUserNotFound,
ErrPluginTypeRequired, ErrPrimaryUserRequired, ErrPluginTypeNotFound)
@routes.get("/instances")
async def get_instances(_: web.Request) -> web.Response:
return web.json_response([client.to_dict() for client in Client.cache.values()])
return web.json_response([instance.to_dict() for instance in PluginInstance.cache.values()])
@routes.get("/instance/{id}")
async def get_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return web.json_response(client.to_dict())
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
if not instance:
return ErrInstanceNotFound
return web.json_response(instance.to_dict())
async def create_instance(user_id: UserID, data: dict) -> web.Response:
return ErrNotImplemented
async def create_instance(instance_id: str, data: dict) -> web.Response:
plugin_type = data.get("type", None)
primary_user = data.get("primary_user", None)
if not plugin_type:
return ErrPluginTypeRequired
elif not primary_user:
return ErrPrimaryUserRequired
elif not Client.get(primary_user):
return ErrPrimaryUserNotFound
try:
PluginLoader.find(plugin_type)
except KeyError:
return ErrPluginTypeNotFound
db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True),
primary_user=primary_user, config=data.get("config", ""))
instance = PluginInstance(db_instance)
instance.load()
PluginInstance.db.add(db_instance)
PluginInstance.db.commit()
await instance.start()
return web.json_response(instance.to_dict())
async def update_instance(client: Client, data: dict) -> web.Response:
return ErrNotImplemented
async def update_instance(instance: PluginInstance, data: dict) -> web.Response:
if not await instance.update_primary_user(data.get("primary_user")):
return ErrPrimaryUserNotFound
instance.update_id(data.get("id", None))
instance.update_enabled(data.get("enabled", None))
instance.update_config(data.get("config", None))
await instance.update_started(data.get("started", None))
instance.db.commit()
return web.json_response(instance.to_dict())
@routes.put("/instance/{id}")
async def update_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
if not client:
return await create_instance(user_id, data)
if not instance:
return await create_instance(instance_id, data)
else:
return await update_instance(client, data)
return await update_instance(instance, data)
@routes.delete("/instance/{id}")
async def delete_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return ErrNotImplemented
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
if not instance:
return ErrInstanceNotFound
if instance.started:
await instance.stop()
instance.delete()
return RespDeleted

View File

@ -41,6 +41,31 @@ ErrClientNotFound = web.json_response({
"errcode": "client_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrPrimaryUserNotFound = web.json_response({
"error": "Client for given primary user not found",
"errcode": "primary_user_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrInstanceNotFound = web.json_response({
"error": "Plugin instance not found",
"errcode": "instance_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrPluginTypeNotFound = web.json_response({
"error": "Given plugin type not found",
"errcode": "plugin_type_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrPluginTypeRequired = web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPrimaryUserRequired = web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPathNotFound = web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",

View File

@ -346,6 +346,9 @@ components:
primary_user:
type: string
example: '@putkiteippi:maunium.net'
config:
type: string
example: "YAML"
MatrixClient:
type: object
properties:

View File

@ -15,7 +15,7 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import React from "react"
import ReactDOM from "react-dom"
import "./style/base"
import "./style/index.sass"
import MaubotManager from "./MaubotManager"
ReactDOM.render(<MaubotManager/>, document.getElementById("root"))

View File

@ -28,7 +28,6 @@ if TYPE_CHECKING:
from .command_spec import CommandSpec
from mautrix.util.config import BaseProxyConfig
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
@ -69,3 +68,7 @@ class Plugin(ABC):
@classmethod
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]:
return None
def on_external_config_update(self) -> None:
if self.config:
self.config.load_and_update()

View File

@ -48,4 +48,8 @@ setuptools.setup(
data_files=[
(".", ["example-config.yaml"]),
],
package_data={
"maubot": ["management/frontend/build/*", "management/frontend/build/static/css/*",
"management/frontend/build/static/js/*"],
},
)