Refactor __main__.py and fix things

This commit is contained in:
Tulir Asokan 2021-11-19 15:22:54 +02:00
parent c685eb5e08
commit 7c9668d8bc
9 changed files with 84 additions and 117 deletions

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,9 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging.config
import argparse
import asyncio import asyncio
import signal
import copy from mautrix.util.program import Program
import sys
from .config import Config from .config import Config
from .db import init as init_db from .db import init as init_db
@ -27,70 +24,58 @@ from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api from .management.api import init as init_mgmt_api
from .lib.future_awaitable import FutureAwaitable
from .__meta__ import __version__ from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
prog="python -m maubot")
parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
args = parser.parse_args()
config = Config(args.config, args.base_config) class Maubot(Program):
config.load() config: Config
config.update() server: MaubotServer
logging.config.dictConfig(copy.deepcopy(config["logging"])) config_class = Config
module = "maubot"
name = "maubot"
version = __version__
command = "python -m maubot"
description = "A plugin-based Matrix bot system."
loop = asyncio.get_event_loop() def prepare_log_websocket(self) -> None:
from .management.api.log import init, stop_all
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))
stop_log_listener = None def prepare(self) -> None:
if config["api_features.log"]: super().prepare()
from .management.api.log import init as init_log_listener, stop_all as stop_log_listener
init_log_listener(loop) if self.config["api_features.log"]:
self.prepare_log_websocket()
log = logging.getLogger("maubot.init") init_zip_loader(self.config)
log.info(f"Initializing maubot {__version__}") init_db(self.config)
clients = init_client_class(self.config, self.loop)
self.add_startup_actions(*(client.start() for client in clients))
management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop)
init_zip_loader(config) plugins = init_plugin_instance_class(self.config, self.server, self.loop)
db_engine = init_db(config) for plugin in plugins:
clients = init_client_class(config, loop) plugin.load()
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(config, server, loop)
for plugin in plugins: async def start(self) -> None:
plugin.load() if Client.crypto_db:
self.log.debug("Starting client crypto database")
await Client.crypto_db.start()
await super().start()
await self.server.start()
signal.signal(signal.SIGINT, signal.default_int_handler) async def stop(self) -> None:
signal.signal(signal.SIGTERM, signal.default_int_handler) self.add_shutdown_actions(*(client.stop() for client in Client.cache.values()))
await super().stop()
self.log.debug("Stopping server")
try:
await asyncio.wait_for(self.server.stop(), 5)
except asyncio.TimeoutError:
self.log.warning("Stopping server timed out")
try: Maubot().run()
log.info("Starting server")
loop.run_until_complete(server.start())
if Client.crypto_db:
log.debug("Starting client crypto database")
loop.run_until_complete(Client.crypto_db.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.info("Startup actions complete, running forever")
loop.run_forever()
except KeyboardInterrupt:
log.info("Interrupt received, stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()]))
if stop_log_listener is not None:
log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener())
log.debug("Stopping server")
try:
loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop))
except asyncio.TimeoutError:
log.warning("Stopping server timed out")
log.debug("Closing event loop")
loop.close()
log.debug("Everything stopped, shutting down")
sys.exit(0)

View File

@ -14,17 +14,15 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from os import path
import asyncio import asyncio
import logging import logging
from aiohttp import ClientSession from aiohttp import ClientSession
from yarl import URL
from mautrix.errors import MatrixInvalidToken, MatrixRequestError from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter, StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter) PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
@ -33,13 +31,12 @@ from .db import DBClient
from .matrix import MaubotMatrixClient from .matrix import MaubotMatrixClient
try: try:
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore, from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, CryptoStore
PickleCryptoStore)
class SQLStateStore(BaseSQLStateStore, CryptoStateStore): class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass pass
except ImportError: except ImportError as e:
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
SQLStateStore = BaseSQLStateStore SQLStateStore = BaseSQLStateStore
@ -63,8 +60,7 @@ class Client:
cache: Dict[UserID, 'Client'] = {} cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore() global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_pickle_dir: str = None crypto_db: Optional['AsyncDatabase'] = None
crypto_db: 'AsyncDatabase' = None
references: Set['PluginInstance'] references: Set['PluginInstance']
db_instance: DBClient db_instance: DBClient
@ -90,7 +86,7 @@ class Client:
log=self.log, loop=self.loop, device_id=self.device_id, log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance), sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store) state_store=self.global_state_store)
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir): if OlmMachine and self.device_id and self.crypto_db:
self.crypto_store = self._make_crypto_store() self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto self.client.crypto = self.crypto
@ -109,9 +105,6 @@ class Client:
def _make_crypto_store(self) -> 'CryptoStore': def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db: if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db) return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
elif self.crypto_pickle_dir:
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
raise ValueError("Crypto database not configured") raise ValueError("Crypto database not configured")
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
@ -330,7 +323,7 @@ class Client:
return self.db_instance.access_token return self.db_instance.access_token
@property @property
def device_id(self) -> str: def device_id(self) -> DeviceID:
return self.db_instance.device_id return self.db_instance.device_id
@property @property
@ -403,25 +396,9 @@ def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.loop = loop Client.loop = loop
if OlmMachine: if OlmMachine:
db_type = config["crypto_database.type"] db_url = config["crypto_database"]
if db_type == "default": if db_url == "default":
db_url = config["database"] db_url = config["database"]
parsed_url = URL(db_url) Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
if parsed_url.scheme == "sqlite":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif parsed_url.scheme == "postgres" or parsed_url.scheme == "postgresql":
if not PgCryptoStore:
log.warning("Default database is postgres, but asyncpg is not installed. "
"Encryption will not work.")
else:
Client.crypto_db = AsyncDatabase(url=db_url,
upgrade_table=PgCryptoStore.upgrade_table)
elif db_type == "pickle":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif (db_type == "postgres" or db_type == "postgresql") and PgCryptoStore:
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
upgrade_table=PgCryptoStore.upgrade_table)
else:
raise ValueError("Unsupported crypto database type")
return Client.all() return Client.all()

View File

@ -32,9 +32,11 @@ class Config(BaseFileConfig):
base = helper.base base = helper.base
copy = helper.copy copy = helper.copy
copy("database") copy("database")
copy("crypto_database.type") if isinstance(self["crypto_database"], dict):
copy("crypto_database.postgres_uri") if self["crypto_database.type"] == "postgres":
copy("crypto_database.pickle_dir") base["crypto_database"] = self["crypto_database.postgres_uri"]
else:
copy("crypto_database")
copy("plugin_directories.upload") copy("plugin_directories.upload")
copy("plugin_directories.load") copy("plugin_directories.load")
copy("plugin_directories.trash") copy("plugin_directories.trash")

View File

@ -2,22 +2,11 @@
# Other DBMSes supported by SQLAlchemy may or may not work. # Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples: # Format examples:
# SQLite: sqlite:///filename.db # SQLite: sqlite:///filename.db
# Postgres: postgres://username:password@hostname/dbname # Postgres: postgresql://username:password@hostname/dbname
database: sqlite:///maubot.db database: sqlite:///maubot.db
# Database for encryption data. # Separate database URL for the crypto database. "default" means use the same database as above.
crypto_database: crypto_database: default
# Type of database. Either "default", "pickle" or "postgres".
# When set to default, using SQLite as the main database will use pickle as the crypto database
# and using Postgres as the main database will use the same one as the crypto database.
#
# When using pickle, individual crypto databases are stored in the pickle_dir directory.
# When using non-default postgres, postgres_uri is used to connect to postgres.
#
# WARNING: The pickle database is dangerous and should not be used in production.
type: default
postgres_uri: postgres://username:password@hostname/dbname
pickle_dir: ./crypto
plugin_directories: plugin_directories:
# The directory where uploaded new plugins should be stored. # The directory where uploaded new plugins should be stored.

View File

@ -0,0 +1,9 @@
from typing import Callable, Awaitable, Generator, Any
class FutureAwaitable:
def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
self._func = func
def __await__(self) -> Generator[Any, None, None]:
return self._func().__await__()

View File

@ -93,6 +93,7 @@ def init(loop: asyncio.AbstractEventLoop) -> None:
async def stop_all() -> None: async def stop_all() -> None:
log.debug("Closing log listener websockets")
log_root.removeHandler(handler) log_root.removeHandler(handler)
for socket in sockets: for socket in sockets:
try: try:

View File

@ -3,9 +3,10 @@
#/postgres #/postgres
psycopg2-binary>=2,<3 psycopg2-binary>=2,<3
asyncpg>=0.20,<0.26
#/e2be #/e2be
asyncpg>=0.20,<0.25 aiosqlite>=0.16,<0.18
python-olm>=3,<4 python-olm>=3,<4
pycryptodome>=3,<4 pycryptodome>=3,<4
unpaddedbase64>=1,<2 unpaddedbase64>=1,<2

View File

@ -1,4 +1,4 @@
mautrix>=0.10.9,<0.11 mautrix>=0.11,<0.12
aiohttp>=3,<4 aiohttp>=3,<4
yarl>=1,<2 yarl>=1,<2
SQLAlchemy>=1,<1.4 SQLAlchemy>=1,<1.4

View File

@ -57,15 +57,18 @@ setuptools.setup(
mbc=maubot.cli:app mbc=maubot.cli:app
""", """,
data_files=[ data_files=[
(".", ["example-config.yaml", "alembic.ini"]), (".", ["maubot/example-config.yaml", "alembic.ini"]),
("alembic", ["alembic/env.py"]), ("alembic", ["alembic/env.py"]),
("alembic/versions", glob.glob("alembic/versions/*.py")), ("alembic/versions", glob.glob("alembic/versions/*.py")),
], ],
package_data={ package_data={
"maubot": ["management/frontend/build/*", "maubot": [
"management/frontend/build/static/css/*", "example-config.yaml",
"management/frontend/build/static/js/*", "management/frontend/build/*",
"management/frontend/build/static/media/*"], "management/frontend/build/static/css/*",
"management/frontend/build/static/js/*",
"management/frontend/build/static/media/*",
],
"maubot.cli": ["res/*"], "maubot.cli": ["res/*"],
}, },
) )