mirror of
https://github.com/maubot/maubot.git
synced 2024-10-01 01:06:10 -04:00
Implement client API
This commit is contained in:
parent
bc87b2a02b
commit
383c9ce5ec
@ -74,18 +74,21 @@ try:
|
||||
log.info("Starting server")
|
||||
loop.run_until_complete(server.start())
|
||||
log.info("Starting clients and plugins")
|
||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
|
||||
log.info("Startup actions complete, running forever")
|
||||
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
log.debug("Interrupt received, stopping HTTP clients/servers and saving database")
|
||||
log.info("Interrupt received, stopping HTTP clients/servers and saving database")
|
||||
if periodic_commit_task is not None:
|
||||
periodic_commit_task.cancel()
|
||||
for client in Client.cache.values():
|
||||
client.stop()
|
||||
log.debug("Stopping clients")
|
||||
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
|
||||
loop=loop))
|
||||
db_session.commit()
|
||||
log.debug("Stopping server")
|
||||
loop.run_until_complete(server.stop())
|
||||
log.debug("Closing event loop")
|
||||
loop.close()
|
||||
log.debug("Everything stopped, shutting down")
|
||||
sys.exit(0)
|
||||
|
@ -92,7 +92,7 @@ class Client:
|
||||
self.db_instance.enabled = False
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.filter_id = await self.client.create_filter(Filter(
|
||||
self.db_instance.filter_id = await self.client.create_filter(Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
@ -122,9 +122,18 @@ class Client:
|
||||
def stop_sync(self) -> None:
|
||||
self.client.stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.started = False
|
||||
self.stop_sync()
|
||||
async def stop(self) -> None:
|
||||
if self.started:
|
||||
self.started = False
|
||||
await self.stop_plugins()
|
||||
self.stop_sync()
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db.delete(self.db_instance)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
@ -158,6 +167,44 @@ class Client:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room(evt.room_id)
|
||||
|
||||
async def update_started(self, started: bool) -> None:
|
||||
if started is None or started == self.started:
|
||||
return
|
||||
if started:
|
||||
await self.start()
|
||||
else:
|
||||
await self.stop()
|
||||
|
||||
async def update_displayname(self, displayname: str) -> None:
|
||||
if not displayname or displayname == self.displayname:
|
||||
return
|
||||
self.db_instance.displayname = displayname
|
||||
await self.client.set_displayname(self.displayname)
|
||||
|
||||
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
|
||||
if not avatar_url or avatar_url == self.avatar_url:
|
||||
return
|
||||
self.db_instance.avatar_url = avatar_url
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
|
||||
async def update_access_details(self, access_token: str, homeserver: str) -> None:
|
||||
if not access_token and not homeserver:
|
||||
return
|
||||
elif access_token == self.access_token and homeserver == self.homeserver:
|
||||
return
|
||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token, loop=self.loop,
|
||||
client_session=self.http_client, log=self.log)
|
||||
mxid = await new_client.whoami()
|
||||
if mxid != self.id:
|
||||
raise ValueError("MXID mismatch")
|
||||
new_client.store = self.db_instance
|
||||
self.stop_sync()
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
self.db_instance.access_token = access_token
|
||||
self.start_sync()
|
||||
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
@ -172,11 +219,6 @@ class Client:
|
||||
def access_token(self) -> str:
|
||||
return self.db_instance.access_token
|
||||
|
||||
@access_token.setter
|
||||
def access_token(self, value: str) -> None:
|
||||
self.client.api.token = value
|
||||
self.db_instance.access_token = value
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
@ -189,25 +231,24 @@ class Client:
|
||||
def next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
|
||||
@next_batch.setter
|
||||
def next_batch(self, value: SyncToken) -> None:
|
||||
self.db_instance.next_batch = value
|
||||
|
||||
@property
|
||||
def filter_id(self) -> FilterID:
|
||||
return self.db_instance.filter_id
|
||||
|
||||
@filter_id.setter
|
||||
def filter_id(self, value: FilterID) -> None:
|
||||
self.db_instance.filter_id = value
|
||||
|
||||
@property
|
||||
def sync(self) -> bool:
|
||||
return self.db_instance.sync
|
||||
|
||||
@sync.setter
|
||||
def sync(self, value: bool) -> None:
|
||||
if value == self.db_instance.sync:
|
||||
return
|
||||
self.db_instance.sync = value
|
||||
if self.started:
|
||||
if value:
|
||||
self.start_sync()
|
||||
else:
|
||||
self.stop_sync()
|
||||
|
||||
@property
|
||||
def autojoin(self) -> bool:
|
||||
@ -227,18 +268,10 @@ class Client:
|
||||
def displayname(self) -> str:
|
||||
return self.db_instance.displayname
|
||||
|
||||
@displayname.setter
|
||||
def displayname(self, value: str) -> None:
|
||||
self.db_instance.displayname = value
|
||||
|
||||
@property
|
||||
def avatar_url(self) -> ContentURI:
|
||||
return self.db_instance.avatar_url
|
||||
|
||||
@avatar_url.setter
|
||||
def avatar_url(self, value: ContentURI) -> None:
|
||||
self.db_instance.avatar_url = value
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
@ -17,15 +17,20 @@ from json import JSONDecodeError
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.types import UserID, SyncToken, FilterID
|
||||
from mautrix.errors import MatrixRequestError, MatrixInvalidToken
|
||||
from mautrix.client import Client as MatrixClient
|
||||
|
||||
from ...db import DBClient
|
||||
from ...client import Client
|
||||
from .base import routes
|
||||
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
|
||||
from .responses import (RespDeleted, ErrClientNotFound, ErrBodyNotJSON, ErrClientInUse,
|
||||
ErrBadClientAccessToken, ErrBadClientAccessDetails, ErrMXIDMismatch,
|
||||
ErrUserExists)
|
||||
|
||||
|
||||
@routes.get("/clients")
|
||||
async def get_clients(request: web.Request) -> web.Response:
|
||||
async def get_clients(_: web.Request) -> web.Response:
|
||||
return web.json_response([client.to_dict() for client in Client.cache.values()])
|
||||
|
||||
|
||||
@ -39,17 +44,59 @@ async def get_client(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
async def create_client(user_id: UserID, data: dict) -> web.Response:
|
||||
return ErrNotImplemented
|
||||
homeserver = data.get("homeserver", None)
|
||||
access_token = data.get("access_token", None)
|
||||
new_client = MatrixClient(base_url=homeserver, token=access_token, loop=Client.loop,
|
||||
client_session=Client.http_client)
|
||||
try:
|
||||
mxid = await new_client.whoami()
|
||||
except MatrixInvalidToken:
|
||||
return ErrBadClientAccessToken
|
||||
except MatrixRequestError:
|
||||
return ErrBadClientAccessDetails
|
||||
if user_id == "new":
|
||||
existing_client = Client.get(mxid, None)
|
||||
if existing_client is not None:
|
||||
return ErrUserExists
|
||||
elif mxid != user_id:
|
||||
return ErrMXIDMismatch
|
||||
db_instance = DBClient(id=user_id, homeserver=homeserver, access_token=access_token,
|
||||
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
||||
filter_id=FilterID(""), sync=data.get("sync", True),
|
||||
autojoin=data.get("autojoin", True),
|
||||
displayname=data.get("displayname", ""),
|
||||
avatar_url=data.get("avatar_url", ""))
|
||||
client = Client(db_instance)
|
||||
Client.db.add(db_instance)
|
||||
Client.db.commit()
|
||||
await client.start()
|
||||
return web.json_response(client.to_dict())
|
||||
|
||||
|
||||
async def update_client(client: Client, data: dict) -> web.Response:
|
||||
return ErrNotImplemented
|
||||
try:
|
||||
await client.update_access_details(data.get("access_token", None),
|
||||
data.get("homeserver", None))
|
||||
except MatrixInvalidToken:
|
||||
return ErrBadClientAccessToken
|
||||
except MatrixRequestError:
|
||||
return ErrBadClientAccessDetails
|
||||
except ValueError:
|
||||
return ErrMXIDMismatch
|
||||
await client.update_avatar_url(data.get("avatar_url", None))
|
||||
await client.update_displayname(data.get("displayname", None))
|
||||
await client.update_started(data.get("started", None))
|
||||
client.enabled = data.get("enabled", client.enabled)
|
||||
client.autojoin = data.get("autojoin", client.autojoin)
|
||||
client.sync = data.get("sync", client.sync)
|
||||
return web.json_response(client.to_dict())
|
||||
|
||||
|
||||
@routes.put("/client/{id}")
|
||||
async def update_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
# /client/new always creates a new client
|
||||
client = Client.get(user_id, None) if user_id != "new" else None
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
@ -66,4 +113,9 @@ async def delete_client(request: web.Request) -> web.Response:
|
||||
client = Client.get(user_id, None)
|
||||
if not client:
|
||||
return ErrClientNotFound
|
||||
return ErrNotImplemented
|
||||
if len(client.references) > 0:
|
||||
return ErrClientInUse
|
||||
if client.started:
|
||||
await client.stop()
|
||||
client.delete()
|
||||
return RespDeleted
|
||||
|
@ -16,6 +16,36 @@
|
||||
from http import HTTPStatus
|
||||
from aiohttp import web
|
||||
|
||||
ErrBodyNotJSON = web.json_response({
|
||||
"error": "Request body is not JSON",
|
||||
"errcode": "body_not_json",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrPluginTypeRequired = web.json_response({
|
||||
"error": "Plugin type is required when creating plugin instances",
|
||||
"errcode": "plugin_type_required",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrPrimaryUserRequired = web.json_response({
|
||||
"error": "Primary user is required when creating plugin instances",
|
||||
"errcode": "primary_user_required",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrBadClientAccessToken = web.json_response({
|
||||
"error": "Invalid access token",
|
||||
"errcode": "bad_client_access_token",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrBadClientAccessDetails = web.json_response({
|
||||
"error": "Invalid homeserver or access token",
|
||||
"errcode": "bad_client_access_details"
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrMXIDMismatch = web.json_response({
|
||||
"error": "The Matrix user ID of the client and the user ID of the access token don't match",
|
||||
"errcode": "mxid_mismatch",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrBadAuth = web.json_response({
|
||||
"error": "Invalid username or password",
|
||||
"errcode": "invalid_auth",
|
||||
@ -56,16 +86,6 @@ ErrPluginTypeNotFound = web.json_response({
|
||||
"errcode": "plugin_type_not_found",
|
||||
}, status=HTTPStatus.NOT_FOUND)
|
||||
|
||||
ErrPluginTypeRequired = web.json_response({
|
||||
"error": "Plugin type is required when creating plugin instances",
|
||||
"errcode": "plugin_type_required",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrPrimaryUserRequired = web.json_response({
|
||||
"error": "Primary user is required when creating plugin instances",
|
||||
"errcode": "primary_user_required",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrPathNotFound = web.json_response({
|
||||
"error": "Resource not found",
|
||||
"errcode": "resource_not_found",
|
||||
@ -76,15 +96,20 @@ ErrMethodNotAllowed = web.json_response({
|
||||
"errcode": "method_not_allowed",
|
||||
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||
|
||||
ErrUserExists = web.json_response({
|
||||
"error": "There is already a client with the user ID of that token",
|
||||
"errcode": "user_exists",
|
||||
}, status=HTTPStatus.CONFLICT)
|
||||
|
||||
ErrPluginInUse = web.json_response({
|
||||
"error": "Plugin instances of this type still exist",
|
||||
"errcode": "plugin_in_use",
|
||||
}, status=HTTPStatus.PRECONDITION_FAILED)
|
||||
|
||||
ErrBodyNotJSON = web.json_response({
|
||||
"error": "Request body is not JSON",
|
||||
"errcode": "body_not_json",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
ErrClientInUse = web.json_response({
|
||||
"error": "Plugin instances with this client as their primary user still exist",
|
||||
"errcode": "client_in_use",
|
||||
}, status=HTTPStatus.PRECONDITION_FAILED)
|
||||
|
||||
|
||||
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
||||
|
Loading…
Reference in New Issue
Block a user