diff --git a/maubot/__main__.py b/maubot/__main__.py index 2ef73f9..f07b58c 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,12 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import logging.config -import argparse import asyncio -import signal -import copy -import sys + +from mautrix.util.program import Program from .config import Config from .db import init as init_db @@ -27,70 +24,58 @@ from .client import Client, init as init_client_class from .loader.zip import init as init_zip_loader from .instance import init as init_plugin_instance_class from .management.api import init as init_mgmt_api +from .lib.future_awaitable import FutureAwaitable from .__meta__ import __version__ -parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", - prog="python -m maubot") -parser.add_argument("-c", "--config", type=str, default="config.yaml", - metavar="", help="the path to your config file") -parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml", - metavar="", help="the path to the example config " - "(for automatic config updates)") -args = parser.parse_args() -config = Config(args.config, args.base_config) -config.load() -config.update() +class Maubot(Program): + config: Config + server: MaubotServer -logging.config.dictConfig(copy.deepcopy(config["logging"])) + config_class = Config + module = "maubot" + name = "maubot" + version = __version__ + command = "python -m maubot" + description = "A plugin-based Matrix bot system." -loop = asyncio.get_event_loop() + def prepare_log_websocket(self) -> None: + from .management.api.log import init, stop_all + init(self.loop) + self.add_shutdown_actions(FutureAwaitable(stop_all)) -stop_log_listener = None -if config["api_features.log"]: - from .management.api.log import init as init_log_listener, stop_all as stop_log_listener + def prepare(self) -> None: + super().prepare() - init_log_listener(loop) + if self.config["api_features.log"]: + self.prepare_log_websocket() -log = logging.getLogger("maubot.init") -log.info(f"Initializing maubot {__version__}") + init_zip_loader(self.config) + init_db(self.config) + clients = init_client_class(self.config, self.loop) + self.add_startup_actions(*(client.start() for client in clients)) + management_api = init_mgmt_api(self.config, self.loop) + self.server = MaubotServer(management_api, self.config, self.loop) -init_zip_loader(config) -db_engine = init_db(config) -clients = init_client_class(config, loop) -management_api = init_mgmt_api(config, loop) -server = MaubotServer(management_api, config, loop) -plugins = init_plugin_instance_class(config, server, loop) + plugins = init_plugin_instance_class(self.config, self.server, self.loop) + for plugin in plugins: + plugin.load() -for plugin in plugins: - plugin.load() + async def start(self) -> None: + if Client.crypto_db: + self.log.debug("Starting client crypto database") + await Client.crypto_db.start() + await super().start() + await self.server.start() -signal.signal(signal.SIGINT, signal.default_int_handler) -signal.signal(signal.SIGTERM, signal.default_int_handler) + async def stop(self) -> None: + self.add_shutdown_actions(*(client.stop() for client in Client.cache.values())) + await super().stop() + self.log.debug("Stopping server") + try: + await asyncio.wait_for(self.server.stop(), 5) + except asyncio.TimeoutError: + self.log.warning("Stopping server timed out") -try: - log.info("Starting server") - loop.run_until_complete(server.start()) - if Client.crypto_db: - log.debug("Starting client crypto database") - loop.run_until_complete(Client.crypto_db.start()) - log.info("Starting clients and plugins") - loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) - log.info("Startup actions complete, running forever") - loop.run_forever() -except KeyboardInterrupt: - log.info("Interrupt received, stopping clients") - loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()])) - if stop_log_listener is not None: - log.debug("Closing websockets") - loop.run_until_complete(stop_log_listener()) - log.debug("Stopping server") - try: - loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop)) - except asyncio.TimeoutError: - log.warning("Stopping server timed out") - log.debug("Closing event loop") - loop.close() - log.debug("Everything stopped, shutting down") - sys.exit(0) +Maubot().run() diff --git a/maubot/client.py b/maubot/client.py index 0ffcd00..60fd335 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -14,17 +14,15 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING -from os import path import asyncio import logging from aiohttp import ClientSession -from yarl import URL -from mautrix.errors import MatrixInvalidToken, MatrixRequestError +from mautrix.errors import MatrixInvalidToken from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter, - PresenceState, StateFilter) + PresenceState, StateFilter, DeviceID) from mautrix.client import InternalEventType from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore @@ -33,13 +31,12 @@ from .db import DBClient from .matrix import MaubotMatrixClient try: - from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore, - PickleCryptoStore) + from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, CryptoStore class SQLStateStore(BaseSQLStateStore, CryptoStateStore): pass -except ImportError: +except ImportError as e: OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None SQLStateStore = BaseSQLStateStore @@ -63,8 +60,7 @@ class Client: cache: Dict[UserID, 'Client'] = {} http_client: ClientSession = None global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore() - crypto_pickle_dir: str = None - crypto_db: 'AsyncDatabase' = None + crypto_db: Optional['AsyncDatabase'] = None references: Set['PluginInstance'] db_instance: DBClient @@ -90,7 +86,7 @@ class Client: log=self.log, loop=self.loop, device_id=self.device_id, sync_store=SyncStoreProxy(self.db_instance), state_store=self.global_state_store) - if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir): + if OlmMachine and self.device_id and self.crypto_db: self.crypto_store = self._make_crypto_store() self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) self.client.crypto = self.crypto @@ -109,9 +105,6 @@ class Client: def _make_crypto_store(self) -> 'CryptoStore': if self.crypto_db: return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db) - elif self.crypto_pickle_dir: - return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto", - path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle")) raise ValueError("Crypto database not configured") def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: @@ -330,7 +323,7 @@ class Client: return self.db_instance.access_token @property - def device_id(self) -> str: + def device_id(self) -> DeviceID: return self.db_instance.device_id @property @@ -403,25 +396,9 @@ def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]: Client.loop = loop if OlmMachine: - db_type = config["crypto_database.type"] - if db_type == "default": + db_url = config["crypto_database"] + if db_url == "default": db_url = config["database"] - parsed_url = URL(db_url) - if parsed_url.scheme == "sqlite": - Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] - elif parsed_url.scheme == "postgres" or parsed_url.scheme == "postgresql": - if not PgCryptoStore: - log.warning("Default database is postgres, but asyncpg is not installed. " - "Encryption will not work.") - else: - Client.crypto_db = AsyncDatabase(url=db_url, - upgrade_table=PgCryptoStore.upgrade_table) - elif db_type == "pickle": - Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] - elif (db_type == "postgres" or db_type == "postgresql") and PgCryptoStore: - Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"], - upgrade_table=PgCryptoStore.upgrade_table) - else: - raise ValueError("Unsupported crypto database type") + Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table) return Client.all() diff --git a/maubot/config.py b/maubot/config.py index 34466cc..eb5072b 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -32,9 +32,11 @@ class Config(BaseFileConfig): base = helper.base copy = helper.copy copy("database") - copy("crypto_database.type") - copy("crypto_database.postgres_uri") - copy("crypto_database.pickle_dir") + if isinstance(self["crypto_database"], dict): + if self["crypto_database.type"] == "postgres": + base["crypto_database"] = self["crypto_database.postgres_uri"] + else: + copy("crypto_database") copy("plugin_directories.upload") copy("plugin_directories.load") copy("plugin_directories.trash") diff --git a/example-config.yaml b/maubot/example-config.yaml similarity index 81% rename from example-config.yaml rename to maubot/example-config.yaml index fbb0183..8f3e288 100644 --- a/example-config.yaml +++ b/maubot/example-config.yaml @@ -2,22 +2,11 @@ # Other DBMSes supported by SQLAlchemy may or may not work. # Format examples: # SQLite: sqlite:///filename.db -# Postgres: postgres://username:password@hostname/dbname +# Postgres: postgresql://username:password@hostname/dbname database: sqlite:///maubot.db -# Database for encryption data. -crypto_database: - # Type of database. Either "default", "pickle" or "postgres". - # When set to default, using SQLite as the main database will use pickle as the crypto database - # and using Postgres as the main database will use the same one as the crypto database. - # - # When using pickle, individual crypto databases are stored in the pickle_dir directory. - # When using non-default postgres, postgres_uri is used to connect to postgres. - # - # WARNING: The pickle database is dangerous and should not be used in production. - type: default - postgres_uri: postgres://username:password@hostname/dbname - pickle_dir: ./crypto +# Separate database URL for the crypto database. "default" means use the same database as above. +crypto_database: default plugin_directories: # The directory where uploaded new plugins should be stored. diff --git a/maubot/lib/future_awaitable.py b/maubot/lib/future_awaitable.py new file mode 100644 index 0000000..b55dcb6 --- /dev/null +++ b/maubot/lib/future_awaitable.py @@ -0,0 +1,9 @@ +from typing import Callable, Awaitable, Generator, Any + +class FutureAwaitable: + def __init__(self, func: Callable[[], Awaitable[None]]) -> None: + self._func = func + + def __await__(self) -> Generator[Any, None, None]: + return self._func().__await__() + diff --git a/maubot/management/api/log.py b/maubot/management/api/log.py index d6ec092..3ed5ca1 100644 --- a/maubot/management/api/log.py +++ b/maubot/management/api/log.py @@ -93,6 +93,7 @@ def init(loop: asyncio.AbstractEventLoop) -> None: async def stop_all() -> None: + log.debug("Closing log listener websockets") log_root.removeHandler(handler) for socket in sockets: try: diff --git a/optional-requirements.txt b/optional-requirements.txt index eeef618..0ab6975 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -3,9 +3,10 @@ #/postgres psycopg2-binary>=2,<3 +asyncpg>=0.20,<0.26 #/e2be -asyncpg>=0.20,<0.25 +aiosqlite>=0.16,<0.18 python-olm>=3,<4 pycryptodome>=3,<4 unpaddedbase64>=1,<2 diff --git a/requirements.txt b/requirements.txt index 3420a56..bc888b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -mautrix>=0.10.9,<0.11 +mautrix>=0.11,<0.12 aiohttp>=3,<4 yarl>=1,<2 SQLAlchemy>=1,<1.4 diff --git a/setup.py b/setup.py index 9286b8a..0fd99ff 100644 --- a/setup.py +++ b/setup.py @@ -57,15 +57,18 @@ setuptools.setup( mbc=maubot.cli:app """, data_files=[ - (".", ["example-config.yaml", "alembic.ini"]), + (".", ["maubot/example-config.yaml", "alembic.ini"]), ("alembic", ["alembic/env.py"]), ("alembic/versions", glob.glob("alembic/versions/*.py")), ], package_data={ - "maubot": ["management/frontend/build/*", - "management/frontend/build/static/css/*", - "management/frontend/build/static/js/*", - "management/frontend/build/static/media/*"], + "maubot": [ + "example-config.yaml", + "management/frontend/build/*", + "management/frontend/build/static/css/*", + "management/frontend/build/static/js/*", + "management/frontend/build/static/media/*", + ], "maubot.cli": ["res/*"], }, )