mirror of
https://github.com/maubot/maubot.git
synced 2024-10-01 01:06:10 -04:00
Refactor how plugins are started and update spec
This commit is contained in:
parent
b96d6e6a94
commit
9e066478a9
@ -57,7 +57,7 @@ Base.metadata.create_all()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
init_db(db_session)
|
||||
init_client(loop)
|
||||
clients = init_client(loop)
|
||||
init_plugin_instance_class(db_session, config)
|
||||
management_api = init_management(config, loop)
|
||||
server = MaubotServer(config, management_api, loop)
|
||||
@ -84,9 +84,10 @@ async def periodic_commit():
|
||||
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.gather(
|
||||
server.start(),
|
||||
*[plugin.start() for plugin in plugins]))
|
||||
log.debug("Starting server")
|
||||
loop.run_until_complete(server.start())
|
||||
log.debug("Starting clients and plugins")
|
||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
||||
log.debug("Startup actions complete, running forever")
|
||||
loop.run_until_complete(periodic_commit())
|
||||
loop.run_forever()
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.1.0.dev5"
|
||||
__version__ = "0.1.0.dev6"
|
||||
|
106
maubot/client.py
106
maubot/client.py
@ -18,6 +18,7 @@ from aiohttp import ClientSession
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
EventType, Filter, RoomFilter, RoomEventFilter)
|
||||
|
||||
@ -31,6 +32,7 @@ log = logging.getLogger("maubot.client")
|
||||
|
||||
|
||||
class Client:
|
||||
log: logging.Logger
|
||||
loop: asyncio.AbstractEventLoop
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
http_client: ClientSession = None
|
||||
@ -38,42 +40,97 @@ class Client:
|
||||
references: Set['PluginInstance']
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
started: bool
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
self.cache[self.id] = self
|
||||
self.log = log.getChild(self.id)
|
||||
self.references = set()
|
||||
self.started = False
|
||||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=self.log, loop=self.loop, store=self.db_instance)
|
||||
if self.autojoin:
|
||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
|
||||
def start(self) -> None:
|
||||
asyncio.ensure_future(self._start(), loop=self.loop)
|
||||
|
||||
async def _start(self) -> None:
|
||||
async def start(self, try_n: Optional[int] = 0) -> None:
|
||||
try:
|
||||
if not self.filter_id:
|
||||
self.filter_id = await self.client.create_filter(Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
),
|
||||
),
|
||||
))
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
await self.client.start(self.filter_id)
|
||||
if try_n > 0:
|
||||
await asyncio.sleep(try_n * 10)
|
||||
await self._start(try_n)
|
||||
except Exception:
|
||||
self.log.exception("starting raised exception")
|
||||
self.log.exception("Failed to start")
|
||||
|
||||
async def _start(self, try_n: Optional[int] = 0) -> None:
|
||||
if not self.enabled:
|
||||
self.log.debug("Not starting disabled client")
|
||||
return
|
||||
elif self.started:
|
||||
self.log.warning("Ignoring start() call to started client")
|
||||
return
|
||||
try:
|
||||
user_id = await self.client.whoami()
|
||||
except MatrixInvalidToken as e:
|
||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||
self.enabled = False
|
||||
return
|
||||
except MatrixRequestError:
|
||||
if try_n >= 5:
|
||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||
self.enabled = False
|
||||
else:
|
||||
self.log.exception(f"Failed to get /account/whoami, "
|
||||
f"retrying in {(try_n + 1) * 10}s")
|
||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||
return
|
||||
if user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
||||
self.enabled = False
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.filter_id = await self.client.create_filter(Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
),
|
||||
),
|
||||
))
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
if self.sync:
|
||||
self.client.start(self.filter_id)
|
||||
self.started = True
|
||||
self.log.info("Client started, starting plugin instances...")
|
||||
await self.start_plugins()
|
||||
|
||||
async def start_plugins(self) -> None:
|
||||
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
|
||||
|
||||
async def stop_plugins(self) -> None:
|
||||
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
|
||||
loop=self.loop)
|
||||
|
||||
def stop(self) -> None:
|
||||
self.started = False
|
||||
self.client.stop()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"homeserver": self.homeserver,
|
||||
"access_token": self.access_token,
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"sync": self.sync,
|
||||
"autojoin": self.autojoin,
|
||||
"displayname": self.displayname,
|
||||
"avatar_url": self.avatar_url,
|
||||
"instances": [instance.to_dict() for instance in self.references],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
|
||||
try:
|
||||
@ -111,6 +168,14 @@ class Client:
|
||||
self.client.api.token = value
|
||||
self.db_instance.access_token = value
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool) -> None:
|
||||
self.db_instance.enabled = value
|
||||
|
||||
@property
|
||||
def next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
@ -168,8 +233,7 @@ class Client:
|
||||
# endregion
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> None:
|
||||
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
for client in Client.all():
|
||||
client.start()
|
||||
return Client.all()
|
||||
|
@ -42,6 +42,7 @@ class DBClient(Base):
|
||||
id: UserID = Column(String(255), primary_key=True)
|
||||
homeserver: str = Column(String(255), nullable=False)
|
||||
access_token: str = Column(String(255), nullable=False)
|
||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
next_batch: SyncToken = Column(String(255), nullable=False, default="")
|
||||
filter_id: FilterID = Column(String(255), nullable=False, default="")
|
||||
|
@ -75,6 +75,7 @@ class PluginInstance:
|
||||
if not self.client:
|
||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||
self.enabled = False
|
||||
return
|
||||
self.log.debug("Plugin instance dependencies loaded")
|
||||
self.loader.references.add(self)
|
||||
self.client.references.add(self)
|
||||
@ -93,8 +94,11 @@ class PluginInstance:
|
||||
self.db_instance.config = buf.getvalue()
|
||||
|
||||
async def start(self) -> None:
|
||||
if not self.enabled:
|
||||
self.log.warning(f"Plugin disabled, not starting.")
|
||||
if self.running:
|
||||
self.log.warning("Ignoring start() call to already started plugin")
|
||||
return
|
||||
elif not self.enabled:
|
||||
self.log.warning("Plugin disabled, not starting.")
|
||||
return
|
||||
cls = await self.loader.load()
|
||||
config_class = cls.get_config_class()
|
||||
@ -118,6 +122,9 @@ class PluginInstance:
|
||||
f"with user {self.client.id}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if not self.running:
|
||||
self.log.warning("Ignoring stop() call to non-running plugin")
|
||||
return
|
||||
self.log.debug("Stopping plugin instance...")
|
||||
self.running = False
|
||||
await self.plugin.stop()
|
||||
|
@ -47,6 +47,7 @@ class PluginLoader(ABC):
|
||||
return {
|
||||
"id": self.id,
|
||||
"version": self.version,
|
||||
"instances": [instance.to_dict() for instance in self.references],
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -14,26 +14,56 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from aiohttp import web
|
||||
from json import JSONDecodeError
|
||||
|
||||
from mautrix.types import UserID
|
||||
|
||||
from ...db import DBClient
|
||||
from ...client import Client
|
||||
from .base import routes
|
||||
from .responses import ErrNotImplemented
|
||||
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
|
||||
|
||||
|
||||
@routes.get("/instances")
|
||||
def get_instances(request: web.Request) -> web.Response:
|
||||
return ErrNotImplemented
|
||||
async def get_instances(_: web.Request) -> web.Response:
|
||||
return web.json_response([client.to_dict() for client in Client.cache.values()])
|
||||
|
||||
|
||||
@routes.get("/instance/{id}")
|
||||
def get_instance(request: web.Request) -> web.Response:
|
||||
async def get_instance(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
if not client:
|
||||
return ErrClientNotFound
|
||||
return web.json_response(client.to_dict())
|
||||
|
||||
|
||||
async def create_instance(user_id: UserID, data: dict) -> web.Response:
|
||||
return ErrNotImplemented
|
||||
|
||||
|
||||
async def update_instance(client: Client, data: dict) -> web.Response:
|
||||
return ErrNotImplemented
|
||||
|
||||
|
||||
@routes.put("/instance/{id}")
|
||||
def update_instance(request: web.Request) -> web.Response:
|
||||
return ErrNotImplemented
|
||||
async def update_instance(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
return ErrBodyNotJSON
|
||||
if not client:
|
||||
return await create_instance(user_id, data)
|
||||
else:
|
||||
return await update_instance(client, data)
|
||||
|
||||
|
||||
@routes.delete("/instance/{id}")
|
||||
def delete_instance(request: web.Request) -> web.Response:
|
||||
async def delete_instance(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
if not client:
|
||||
return ErrClientNotFound
|
||||
return ErrNotImplemented
|
||||
|
@ -24,16 +24,9 @@ from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error,
|
||||
from .base import routes, get_config
|
||||
|
||||
|
||||
def _plugin_to_dict(plugin: PluginLoader) -> dict:
|
||||
return {
|
||||
**plugin.to_dict(),
|
||||
"instances": [instance.to_dict() for instance in plugin.references]
|
||||
}
|
||||
|
||||
|
||||
@routes.get("/plugins")
|
||||
async def get_plugins(_) -> web.Response:
|
||||
return web.json_response([_plugin_to_dict(plugin) for plugin in PluginLoader.id_cache.values()])
|
||||
return web.json_response([plugin.to_dict() for plugin in PluginLoader.id_cache.values()])
|
||||
|
||||
|
||||
@routes.get("/plugin/{id}")
|
||||
@ -42,7 +35,7 @@ async def get_plugin(request: web.Request) -> web.Response:
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
if not plugin:
|
||||
return ErrPluginNotFound
|
||||
return web.json_response(_plugin_to_dict(plugin))
|
||||
return web.json_response(plugin.to_dict())
|
||||
|
||||
|
||||
@routes.delete("/plugin/{id}")
|
||||
@ -78,11 +71,11 @@ async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Respo
|
||||
with open(path, "wb") as p:
|
||||
p.write(content)
|
||||
try:
|
||||
ZippedPluginLoader.get(path)
|
||||
plugin = ZippedPluginLoader.get(path)
|
||||
except MaubotZipImportError as e:
|
||||
ZippedPluginLoader.trash(path)
|
||||
return plugin_import_error(str(e), traceback.format_exc())
|
||||
return RespOK
|
||||
return web.json_response(plugin.to_dict())
|
||||
|
||||
|
||||
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str
|
||||
@ -110,7 +103,7 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
|
||||
return plugin_import_error(str(e), traceback.format_exc())
|
||||
await plugin.start_instances()
|
||||
ZippedPluginLoader.trash(old_path, reason="update")
|
||||
return RespOK
|
||||
return web.json_response(plugin.to_dict())
|
||||
|
||||
|
||||
@routes.post("/plugins/upload")
|
||||
|
@ -36,6 +36,11 @@ ErrPluginNotFound = web.json_response({
|
||||
"errcode": "plugin_not_found",
|
||||
}, status=HTTPStatus.NOT_FOUND)
|
||||
|
||||
ErrClientNotFound = web.json_response({
|
||||
"error": "Client not found",
|
||||
"errcode": "client_not_found",
|
||||
}, status=HTTPStatus.NOT_FOUND)
|
||||
|
||||
ErrPathNotFound = web.json_response({
|
||||
"error": "Resource not found",
|
||||
"errcode": "resource_not_found",
|
||||
|
@ -231,12 +231,14 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/MatrixClientList'
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/MatrixClient'
|
||||
401:
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
'/client/{user_id}':
|
||||
'/client/{id}':
|
||||
parameters:
|
||||
- name: user_id
|
||||
- name: id
|
||||
in: path
|
||||
description: The Matrix user ID of the client to get
|
||||
required: true
|
||||
@ -338,38 +340,12 @@ components:
|
||||
enabled:
|
||||
type: boolean
|
||||
example: true
|
||||
started:
|
||||
type: boolean
|
||||
example: true
|
||||
primary_user:
|
||||
type: string
|
||||
example: '@putkiteippi:maunium.net'
|
||||
MatrixClientList:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
example: '@putkiteippi:maunium.net'
|
||||
homeserver:
|
||||
type: string
|
||||
example: 'https://maunium.net'
|
||||
enabled:
|
||||
type: boolean
|
||||
example: true
|
||||
sync:
|
||||
type: boolean
|
||||
example: true
|
||||
autojoin:
|
||||
type: boolean
|
||||
example: true
|
||||
displayname:
|
||||
type: string
|
||||
example: J. E. Saarinen
|
||||
avatar_url:
|
||||
type: string
|
||||
example: 'mxc://maunium.net/FsPQQTntCCqhJMFtwArmJdaU'
|
||||
instance_count:
|
||||
type: integer
|
||||
example: 1
|
||||
MatrixClient:
|
||||
type: object
|
||||
properties:
|
||||
@ -385,6 +361,9 @@ components:
|
||||
enabled:
|
||||
type: boolean
|
||||
example: true
|
||||
started:
|
||||
type: boolean
|
||||
example: true
|
||||
sync:
|
||||
type: boolean
|
||||
example: true
|
||||
|
Loading…
Reference in New Issue
Block a user