More changes

This commit is contained in:
Tulir Asokan 2018-10-16 16:41:02 +03:00
parent 0b246e44a8
commit eef052b1e9
9 changed files with 195 additions and 61 deletions

View File

@ -10,8 +10,9 @@ plugin_directories:
- ./plugins
server:
# The IP:port to listen to.
listen: 0.0.0.0:29316
# The IP and port to listen to.
hostname: 0.0.0.0
port: 29316
# The base management API path.
base_path: /_matrix/maubot
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.

View File

@ -17,9 +17,14 @@ from sqlalchemy import orm
import sqlalchemy as sql
import logging.config
import argparse
import asyncio
import copy
import sys
from .config import Config
from .db import Base, init as init_db
from .server import MaubotServer
from .client import Client, init as init_client
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
@ -36,7 +41,23 @@ logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("maubot")
log.debug(f"Initializing maubot {__version__}")
db_engine = sql.create_engine(config["database"])
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoping.scoped_session(db_factory)
Base.metadata.bind=db_engine
loop = asyncio.get_event_loop()
init_db(db_session)
init_client(loop)
server = MaubotServer(config, loop)
try:
loop.run_until_complete(server.start())
loop.run_forever()
except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping...")
for client in Client.cache.values():
client.stop()
loop.run_until_complete(server.stop())
sys.exit(0)

View File

@ -13,62 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Union, Callable
from typing import Dict, List, Optional
from aiohttp import ClientSession
import asyncio
import logging
from mautrix import Client as MatrixClient
from mautrix.client import EventHandler
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
EventType, MessageEvent)
from mautrix.types import UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, EventType
from .command_spec import ParsedCommand
from .db import DBClient
from .matrix import MaubotMatrixClient
log = logging.getLogger("maubot.client")
class MaubotMatrixClient(MatrixClient):
def __init__(self, maubot_client: 'Client', *args, **kwargs):
super().__init__(*args, **kwargs)
self._maubot_client = maubot_client
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
class Client:
loop: asyncio.AbstractEventLoop
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
@ -78,26 +37,33 @@ class Client:
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
self.client = MaubotMatrixClient(maubot_client=self,
store=self.db_instance,
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=log.getChild(self.id))
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None:
asyncio.ensure_future(self.client.start(), loop=self.loop)
def stop(self) -> None:
self.client.stop()
@classmethod
def get(cls, user_id: UserID) -> Optional['Client']:
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try:
return cls.cache[user_id]
except KeyError:
db_instance = DBClient.query.get(user_id)
db_instance = db_instance or DBClient.query.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
# region Properties
@property
@ -176,3 +142,9 @@ class Client:
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
def init(loop: asyncio.AbstractEventLoop) -> None:
Client.loop = loop
for client in Client.all():
client.start()

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, scoped_session
from sqlalchemy.ext.declarative import declarative_base
import json
@ -89,3 +89,9 @@ class DBCommandSpec(Base):
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
def init(session: scoped_session) -> None:
DBPlugin.query = session.query_property()
DBClient.query = session.query_property()
DBCommandSpec.query = session.query_property()

View File

@ -13,14 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TypeVar, Type
from typing import TypeVar, Type, Dict
from abc import ABC, abstractmethod
from ..plugin_base import Plugin
PluginClass = TypeVar("PluginClass", bound=Plugin)
class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {}
id: str
version: str

View File

@ -29,7 +29,6 @@ class MaubotZipImportError(Exception):
class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
id_cache: Dict[str, 'ZippedPluginLoader'] = {}
path: str
id: str

77
maubot/matrix.py Normal file
View File

@ -0,0 +1,77 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Union, Callable
from mautrix import Client as MatrixClient
from mautrix.client import EventHandler
from mautrix.types import EventType, MessageEvent
from .command_spec import ParsedCommand, CommandSpec
class MaubotMatrixClient(MatrixClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.command_specs: Dict[str, CommandSpec] = {}
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None:
self.command_specs[plugin_id] = spec
self._reparse_command_specs()
def _reparse_command_specs(self) -> None:
self.commands = [parsed_command
for spec in self.command_specs.values()
for parsed_command in spec.parse()]
def remove_command_spec(self, plugin_id: str) -> None:
try:
del self.command_specs[plugin_id]
self._reparse_command_specs()
except KeyError:
pass
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass

View File

@ -22,11 +22,12 @@ if TYPE_CHECKING:
class Plugin(ABC):
def __init__(self, client: 'MaubotMatrixClient') -> None:
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
self.client = client
self.id = plugin_instance_id
def set_command_spec(self, spec: 'CommandSpec') -> None:
pass
self.client.set_command_spec(self.id, spec)
async def start(self) -> None:
pass

54
maubot/server.py Normal file
View File

@ -0,0 +1,54 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
import asyncio
from mautrix.api import PathBuilder
from .config import Config
from .__meta__ import __version__
class MaubotServer:
def __init__(self, config: Config, loop: asyncio.AbstractEventLoop):
self.loop = loop or asyncio.get_event_loop()
self.app = web.Application(loop=self.loop)
self.config = config
path = PathBuilder(config["server.base_path"])
self.app.router.add_get(path.version, self.version)
as_path = PathBuilder(config["server.appservice_base_path"])
self.app.router.add_put(as_path.transactions, self.handle_transaction)
self.runner = web.AppRunner(self.app)
async def start(self) -> None:
await self.runner.setup()
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
await site.start()
async def stop(self) -> None:
await self.runner.cleanup()
@staticmethod
async def version(_: web.Request) -> web.Response:
return web.json_response({
"version": __version__
})
async def handle_transaction(self, request: web.Request) -> web.Response:
return web.Response(status=501)