diff --git a/maubot/__main__.py b/maubot/__main__.py
index 2ef73f9..f07b58c 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
-# Copyright (C) 2019 Tulir Asokan
+# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -13,12 +13,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import logging.config
-import argparse
import asyncio
-import signal
-import copy
-import sys
+
+from mautrix.util.program import Program
from .config import Config
from .db import init as init_db
@@ -27,70 +24,58 @@ from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api
+from .lib.future_awaitable import FutureAwaitable
from .__meta__ import __version__
-parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
- prog="python -m maubot")
-parser.add_argument("-c", "--config", type=str, default="config.yaml",
- metavar="", help="the path to your config file")
-parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
- metavar="", help="the path to the example config "
- "(for automatic config updates)")
-args = parser.parse_args()
-config = Config(args.config, args.base_config)
-config.load()
-config.update()
+class Maubot(Program):
+ config: Config
+ server: MaubotServer
-logging.config.dictConfig(copy.deepcopy(config["logging"]))
+ config_class = Config
+ module = "maubot"
+ name = "maubot"
+ version = __version__
+ command = "python -m maubot"
+ description = "A plugin-based Matrix bot system."
-loop = asyncio.get_event_loop()
+ def prepare_log_websocket(self) -> None:
+ from .management.api.log import init, stop_all
+ init(self.loop)
+ self.add_shutdown_actions(FutureAwaitable(stop_all))
-stop_log_listener = None
-if config["api_features.log"]:
- from .management.api.log import init as init_log_listener, stop_all as stop_log_listener
+ def prepare(self) -> None:
+ super().prepare()
- init_log_listener(loop)
+ if self.config["api_features.log"]:
+ self.prepare_log_websocket()
-log = logging.getLogger("maubot.init")
-log.info(f"Initializing maubot {__version__}")
+ init_zip_loader(self.config)
+ init_db(self.config)
+ clients = init_client_class(self.config, self.loop)
+ self.add_startup_actions(*(client.start() for client in clients))
+ management_api = init_mgmt_api(self.config, self.loop)
+ self.server = MaubotServer(management_api, self.config, self.loop)
-init_zip_loader(config)
-db_engine = init_db(config)
-clients = init_client_class(config, loop)
-management_api = init_mgmt_api(config, loop)
-server = MaubotServer(management_api, config, loop)
-plugins = init_plugin_instance_class(config, server, loop)
+ plugins = init_plugin_instance_class(self.config, self.server, self.loop)
+ for plugin in plugins:
+ plugin.load()
-for plugin in plugins:
- plugin.load()
+ async def start(self) -> None:
+ if Client.crypto_db:
+ self.log.debug("Starting client crypto database")
+ await Client.crypto_db.start()
+ await super().start()
+ await self.server.start()
-signal.signal(signal.SIGINT, signal.default_int_handler)
-signal.signal(signal.SIGTERM, signal.default_int_handler)
+ async def stop(self) -> None:
+ self.add_shutdown_actions(*(client.stop() for client in Client.cache.values()))
+ await super().stop()
+ self.log.debug("Stopping server")
+ try:
+ await asyncio.wait_for(self.server.stop(), 5)
+ except asyncio.TimeoutError:
+ self.log.warning("Stopping server timed out")
-try:
- log.info("Starting server")
- loop.run_until_complete(server.start())
- if Client.crypto_db:
- log.debug("Starting client crypto database")
- loop.run_until_complete(Client.crypto_db.start())
- log.info("Starting clients and plugins")
- loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
- log.info("Startup actions complete, running forever")
- loop.run_forever()
-except KeyboardInterrupt:
- log.info("Interrupt received, stopping clients")
- loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()]))
- if stop_log_listener is not None:
- log.debug("Closing websockets")
- loop.run_until_complete(stop_log_listener())
- log.debug("Stopping server")
- try:
- loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop))
- except asyncio.TimeoutError:
- log.warning("Stopping server timed out")
- log.debug("Closing event loop")
- loop.close()
- log.debug("Everything stopped, shutting down")
- sys.exit(0)
+Maubot().run()
diff --git a/maubot/client.py b/maubot/client.py
index 0ffcd00..60fd335 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -14,17 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
-from os import path
import asyncio
import logging
from aiohttp import ClientSession
-from yarl import URL
-from mautrix.errors import MatrixInvalidToken, MatrixRequestError
+from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
- PresenceState, StateFilter)
+ PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
@@ -33,13 +31,12 @@ from .db import DBClient
from .matrix import MaubotMatrixClient
try:
- from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
- PickleCryptoStore)
+ from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, CryptoStore
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
-except ImportError:
+except ImportError as e:
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
SQLStateStore = BaseSQLStateStore
@@ -63,8 +60,7 @@ class Client:
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
- crypto_pickle_dir: str = None
- crypto_db: 'AsyncDatabase' = None
+ crypto_db: Optional['AsyncDatabase'] = None
references: Set['PluginInstance']
db_instance: DBClient
@@ -90,7 +86,7 @@ class Client:
log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
- if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
+ if OlmMachine and self.device_id and self.crypto_db:
self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
@@ -109,9 +105,6 @@ class Client:
def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
- elif self.crypto_pickle_dir:
- return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
- path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
raise ValueError("Crypto database not configured")
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
@@ -330,7 +323,7 @@ class Client:
return self.db_instance.access_token
@property
- def device_id(self) -> str:
+ def device_id(self) -> DeviceID:
return self.db_instance.device_id
@property
@@ -403,25 +396,9 @@ def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.loop = loop
if OlmMachine:
- db_type = config["crypto_database.type"]
- if db_type == "default":
+ db_url = config["crypto_database"]
+ if db_url == "default":
db_url = config["database"]
- parsed_url = URL(db_url)
- if parsed_url.scheme == "sqlite":
- Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
- elif parsed_url.scheme == "postgres" or parsed_url.scheme == "postgresql":
- if not PgCryptoStore:
- log.warning("Default database is postgres, but asyncpg is not installed. "
- "Encryption will not work.")
- else:
- Client.crypto_db = AsyncDatabase(url=db_url,
- upgrade_table=PgCryptoStore.upgrade_table)
- elif db_type == "pickle":
- Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
- elif (db_type == "postgres" or db_type == "postgresql") and PgCryptoStore:
- Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
- upgrade_table=PgCryptoStore.upgrade_table)
- else:
- raise ValueError("Unsupported crypto database type")
+ Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
return Client.all()
diff --git a/maubot/config.py b/maubot/config.py
index 34466cc..eb5072b 100644
--- a/maubot/config.py
+++ b/maubot/config.py
@@ -32,9 +32,11 @@ class Config(BaseFileConfig):
base = helper.base
copy = helper.copy
copy("database")
- copy("crypto_database.type")
- copy("crypto_database.postgres_uri")
- copy("crypto_database.pickle_dir")
+ if isinstance(self["crypto_database"], dict):
+ if self["crypto_database.type"] == "postgres":
+ base["crypto_database"] = self["crypto_database.postgres_uri"]
+ else:
+ copy("crypto_database")
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")
diff --git a/example-config.yaml b/maubot/example-config.yaml
similarity index 81%
rename from example-config.yaml
rename to maubot/example-config.yaml
index fbb0183..8f3e288 100644
--- a/example-config.yaml
+++ b/maubot/example-config.yaml
@@ -2,22 +2,11 @@
# Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples:
# SQLite: sqlite:///filename.db
-# Postgres: postgres://username:password@hostname/dbname
+# Postgres: postgresql://username:password@hostname/dbname
database: sqlite:///maubot.db
-# Database for encryption data.
-crypto_database:
- # Type of database. Either "default", "pickle" or "postgres".
- # When set to default, using SQLite as the main database will use pickle as the crypto database
- # and using Postgres as the main database will use the same one as the crypto database.
- #
- # When using pickle, individual crypto databases are stored in the pickle_dir directory.
- # When using non-default postgres, postgres_uri is used to connect to postgres.
- #
- # WARNING: The pickle database is dangerous and should not be used in production.
- type: default
- postgres_uri: postgres://username:password@hostname/dbname
- pickle_dir: ./crypto
+# Separate database URL for the crypto database. "default" means use the same database as above.
+crypto_database: default
plugin_directories:
# The directory where uploaded new plugins should be stored.
diff --git a/maubot/lib/future_awaitable.py b/maubot/lib/future_awaitable.py
new file mode 100644
index 0000000..b55dcb6
--- /dev/null
+++ b/maubot/lib/future_awaitable.py
@@ -0,0 +1,9 @@
+from typing import Callable, Awaitable, Generator, Any
+
+class FutureAwaitable:
+ def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
+ self._func = func
+
+ def __await__(self) -> Generator[Any, None, None]:
+ return self._func().__await__()
+
diff --git a/maubot/management/api/log.py b/maubot/management/api/log.py
index d6ec092..3ed5ca1 100644
--- a/maubot/management/api/log.py
+++ b/maubot/management/api/log.py
@@ -93,6 +93,7 @@ def init(loop: asyncio.AbstractEventLoop) -> None:
async def stop_all() -> None:
+ log.debug("Closing log listener websockets")
log_root.removeHandler(handler)
for socket in sockets:
try:
diff --git a/optional-requirements.txt b/optional-requirements.txt
index eeef618..0ab6975 100644
--- a/optional-requirements.txt
+++ b/optional-requirements.txt
@@ -3,9 +3,10 @@
#/postgres
psycopg2-binary>=2,<3
+asyncpg>=0.20,<0.26
#/e2be
-asyncpg>=0.20,<0.25
+aiosqlite>=0.16,<0.18
python-olm>=3,<4
pycryptodome>=3,<4
unpaddedbase64>=1,<2
diff --git a/requirements.txt b/requirements.txt
index 3420a56..bc888b7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-mautrix>=0.10.9,<0.11
+mautrix>=0.11,<0.12
aiohttp>=3,<4
yarl>=1,<2
SQLAlchemy>=1,<1.4
diff --git a/setup.py b/setup.py
index 9286b8a..0fd99ff 100644
--- a/setup.py
+++ b/setup.py
@@ -57,15 +57,18 @@ setuptools.setup(
mbc=maubot.cli:app
""",
data_files=[
- (".", ["example-config.yaml", "alembic.ini"]),
+ (".", ["maubot/example-config.yaml", "alembic.ini"]),
("alembic", ["alembic/env.py"]),
("alembic/versions", glob.glob("alembic/versions/*.py")),
],
package_data={
- "maubot": ["management/frontend/build/*",
- "management/frontend/build/static/css/*",
- "management/frontend/build/static/js/*",
- "management/frontend/build/static/media/*"],
+ "maubot": [
+ "example-config.yaml",
+ "management/frontend/build/*",
+ "management/frontend/build/static/css/*",
+ "management/frontend/build/static/js/*",
+ "management/frontend/build/static/media/*",
+ ],
"maubot.cli": ["res/*"],
},
)