Refactor __main__.py and fix things

This commit is contained in:
Tulir Asokan 2021-11-19 15:22:54 +02:00
parent c685eb5e08
commit 7c9668d8bc
9 changed files with 84 additions and 117 deletions

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging.config
import argparse
import asyncio
import signal
import copy
import sys
from mautrix.util.program import Program
from .config import Config
from .db import init as init_db
@ -27,70 +24,58 @@ from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api
from .lib.future_awaitable import FutureAwaitable
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
prog="python -m maubot")
parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
args = parser.parse_args()
config = Config(args.config, args.base_config)
config.load()
config.update()
class Maubot(Program):
config: Config
server: MaubotServer
logging.config.dictConfig(copy.deepcopy(config["logging"]))
config_class = Config
module = "maubot"
name = "maubot"
version = __version__
command = "python -m maubot"
description = "A plugin-based Matrix bot system."
loop = asyncio.get_event_loop()
def prepare_log_websocket(self) -> None:
from .management.api.log import init, stop_all
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))
stop_log_listener = None
if config["api_features.log"]:
from .management.api.log import init as init_log_listener, stop_all as stop_log_listener
def prepare(self) -> None:
super().prepare()
init_log_listener(loop)
if self.config["api_features.log"]:
self.prepare_log_websocket()
log = logging.getLogger("maubot.init")
log.info(f"Initializing maubot {__version__}")
init_zip_loader(self.config)
init_db(self.config)
clients = init_client_class(self.config, self.loop)
self.add_startup_actions(*(client.start() for client in clients))
management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop)
init_zip_loader(config)
db_engine = init_db(config)
clients = init_client_class(config, loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(config, server, loop)
plugins = init_plugin_instance_class(self.config, self.server, self.loop)
for plugin in plugins:
plugin.load()
for plugin in plugins:
plugin.load()
async def start(self) -> None:
if Client.crypto_db:
self.log.debug("Starting client crypto database")
await Client.crypto_db.start()
await super().start()
await self.server.start()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
async def stop(self) -> None:
self.add_shutdown_actions(*(client.stop() for client in Client.cache.values()))
await super().stop()
self.log.debug("Stopping server")
try:
await asyncio.wait_for(self.server.stop(), 5)
except asyncio.TimeoutError:
self.log.warning("Stopping server timed out")
try:
log.info("Starting server")
loop.run_until_complete(server.start())
if Client.crypto_db:
log.debug("Starting client crypto database")
loop.run_until_complete(Client.crypto_db.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.info("Startup actions complete, running forever")
loop.run_forever()
except KeyboardInterrupt:
log.info("Interrupt received, stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()]))
if stop_log_listener is not None:
log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener())
log.debug("Stopping server")
try:
loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop))
except asyncio.TimeoutError:
log.warning("Stopping server timed out")
log.debug("Closing event loop")
loop.close()
log.debug("Everything stopped, shutting down")
sys.exit(0)
Maubot().run()

View File

@ -14,17 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from os import path
import asyncio
import logging
from aiohttp import ClientSession
from yarl import URL
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter)
PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
@ -33,13 +31,12 @@ from .db import DBClient
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
PickleCryptoStore)
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, CryptoStore
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
except ImportError:
except ImportError as e:
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
SQLStateStore = BaseSQLStateStore
@ -63,8 +60,7 @@ class Client:
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_pickle_dir: str = None
crypto_db: 'AsyncDatabase' = None
crypto_db: Optional['AsyncDatabase'] = None
references: Set['PluginInstance']
db_instance: DBClient
@ -90,7 +86,7 @@ class Client:
log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
if OlmMachine and self.device_id and self.crypto_db:
self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
@ -109,9 +105,6 @@ class Client:
def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
elif self.crypto_pickle_dir:
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
raise ValueError("Crypto database not configured")
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
@ -330,7 +323,7 @@ class Client:
return self.db_instance.access_token
@property
def device_id(self) -> str:
def device_id(self) -> DeviceID:
return self.db_instance.device_id
@property
@ -403,25 +396,9 @@ def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.loop = loop
if OlmMachine:
db_type = config["crypto_database.type"]
if db_type == "default":
db_url = config["crypto_database"]
if db_url == "default":
db_url = config["database"]
parsed_url = URL(db_url)
if parsed_url.scheme == "sqlite":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif parsed_url.scheme == "postgres" or parsed_url.scheme == "postgresql":
if not PgCryptoStore:
log.warning("Default database is postgres, but asyncpg is not installed. "
"Encryption will not work.")
else:
Client.crypto_db = AsyncDatabase(url=db_url,
upgrade_table=PgCryptoStore.upgrade_table)
elif db_type == "pickle":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif (db_type == "postgres" or db_type == "postgresql") and PgCryptoStore:
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
upgrade_table=PgCryptoStore.upgrade_table)
else:
raise ValueError("Unsupported crypto database type")
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
return Client.all()

View File

@ -32,9 +32,11 @@ class Config(BaseFileConfig):
base = helper.base
copy = helper.copy
copy("database")
copy("crypto_database.type")
copy("crypto_database.postgres_uri")
copy("crypto_database.pickle_dir")
if isinstance(self["crypto_database"], dict):
if self["crypto_database.type"] == "postgres":
base["crypto_database"] = self["crypto_database.postgres_uri"]
else:
copy("crypto_database")
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")

View File

@ -2,22 +2,11 @@
# Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples:
# SQLite: sqlite:///filename.db
# Postgres: postgres://username:password@hostname/dbname
# Postgres: postgresql://username:password@hostname/dbname
database: sqlite:///maubot.db
# Database for encryption data.
crypto_database:
# Type of database. Either "default", "pickle" or "postgres".
# When set to default, using SQLite as the main database will use pickle as the crypto database
# and using Postgres as the main database will use the same one as the crypto database.
#
# When using pickle, individual crypto databases are stored in the pickle_dir directory.
# When using non-default postgres, postgres_uri is used to connect to postgres.
#
# WARNING: The pickle database is dangerous and should not be used in production.
type: default
postgres_uri: postgres://username:password@hostname/dbname
pickle_dir: ./crypto
# Separate database URL for the crypto database. "default" means use the same database as above.
crypto_database: default
plugin_directories:
# The directory where uploaded new plugins should be stored.

View File

@ -0,0 +1,9 @@
from typing import Callable, Awaitable, Generator, Any
class FutureAwaitable:
def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
self._func = func
def __await__(self) -> Generator[Any, None, None]:
return self._func().__await__()

View File

@ -93,6 +93,7 @@ def init(loop: asyncio.AbstractEventLoop) -> None:
async def stop_all() -> None:
log.debug("Closing log listener websockets")
log_root.removeHandler(handler)
for socket in sockets:
try:

View File

@ -3,9 +3,10 @@
#/postgres
psycopg2-binary>=2,<3
asyncpg>=0.20,<0.26
#/e2be
asyncpg>=0.20,<0.25
aiosqlite>=0.16,<0.18
python-olm>=3,<4
pycryptodome>=3,<4
unpaddedbase64>=1,<2

View File

@ -1,4 +1,4 @@
mautrix>=0.10.9,<0.11
mautrix>=0.11,<0.12
aiohttp>=3,<4
yarl>=1,<2
SQLAlchemy>=1,<1.4

View File

@ -57,15 +57,18 @@ setuptools.setup(
mbc=maubot.cli:app
""",
data_files=[
(".", ["example-config.yaml", "alembic.ini"]),
(".", ["maubot/example-config.yaml", "alembic.ini"]),
("alembic", ["alembic/env.py"]),
("alembic/versions", glob.glob("alembic/versions/*.py")),
],
package_data={
"maubot": ["management/frontend/build/*",
"management/frontend/build/static/css/*",
"management/frontend/build/static/js/*",
"management/frontend/build/static/media/*"],
"maubot": [
"example-config.yaml",
"management/frontend/build/*",
"management/frontend/build/static/css/*",
"management/frontend/build/static/js/*",
"management/frontend/build/static/media/*",
],
"maubot.cli": ["res/*"],
},
)