diff --git a/maubot/cli/cliq/cliq.py b/maubot/cli/cliq/cliq.py index 6a10e28..f8992b8 100644 --- a/maubot/cli/cliq/cliq.py +++ b/maubot/cli/cliq/cliq.py @@ -15,15 +15,41 @@ # along with this program. If not, see . from typing import Any, Callable, Union, Optional import functools +import inspect +import asyncio + +import aiohttp from prompt_toolkit.validation import Validator from questionary import prompt import click from ..base import app +from ..config import get_token from .validators import Required, ClickValidator +def with_http(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + async with aiohttp.ClientSession() as sess: + return await func(*args, sess=sess, **kwargs) + + return wrapper + + +def with_authenticated_http(func): + @functools.wraps(func) + async def wrapper(*args, server: str, **kwargs): + server, token = get_token(server) + if not token: + return + async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess: + return await func(*args, sess=sess, server=server, **kwargs) + + return wrapper + + def command(help: str) -> Callable[[Callable], Callable]: def decorator(func) -> Callable: questions = func.__inquirer_questions__.copy() @@ -52,7 +78,10 @@ def command(help: str) -> Callable[[Callable], Callable]: if not resp and question_list: return kwargs = {**kwargs, **resp} - func(*args, **kwargs) + + res = func(*args, **kwargs) + if inspect.isawaitable(res): + asyncio.run(res) return app.command(help=help)(wrapper) diff --git a/maubot/cli/commands/auth.py b/maubot/cli/commands/auth.py index c537cf1..66c8b6e 100644 --- a/maubot/cli/commands/auth.py +++ b/maubot/cli/commands/auth.py @@ -13,13 +13,11 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from urllib.parse import quote -from urllib.request import urlopen, Request -from urllib.error import HTTPError -import functools import json from colorama import Fore +from yarl import URL +import aiohttp import click from ..config import get_token @@ -27,8 +25,6 @@ from ..cliq import cliq history_count: int = 10 -enc = functools.partial(quote, safe="") - friendly_errors = { "server_not_found": "Registration target server not found.\n\n" "To log in or register through maubot, you must add the server to the\n" @@ -37,6 +33,15 @@ friendly_errors = { } +async def list_servers(server: str, sess: aiohttp.ClientSession) -> None: + url = URL(server) / "_matrix/maubot/v1/client/auth/servers" + async with sess.get(url) as resp: + data = await resp.json() + print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}") + for server in data.keys(): + print(f"* {Fore.CYAN}{server}{Fore.RESET}") + + @cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") @cliq.option("-u", "--username", help="The username to log in with", required_unless="list") @@ -46,42 +51,40 @@ friendly_errors = { required=False, prompt=False) @click.option("-r", "--register", help="Register instead of logging in", is_flag=True, default=False) +@click.option("-c", "--update-client", help="Instead of returning the access token, " + "create or update a client in maubot using it", + is_flag=True, default=False) @click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False) -def auth(homeserver: str, username: str, password: str, server: str, register: bool, list: bool - ) -> None: - server, token = get_token(server) - if not token: - return - headers = {"Authorization": f"Bearer {token}"} +@cliq.with_authenticated_http +async def auth(homeserver: str, username: str, password: str, server: str, register: bool, + list: bool, update_client: bool, sess: aiohttp.ClientSession) -> None: if list: - url = f"{server}/_matrix/maubot/v1/client/auth/servers" - with urlopen(Request(url, headers=headers)) as resp_data: - resp = json.load(resp_data) - print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}") - for server in resp.keys(): - print(f"* {Fore.CYAN}{server}{Fore.RESET}") - return + await list_servers(server, sess) + return endpoint = "register" if register else "login" - headers["Content-Type"] = "application/json" - url = f"{server}/_matrix/maubot/v1/client/auth/{enc(homeserver)}/{endpoint}" - req = Request(url, headers=headers, - data=json.dumps({ - "username": username, - "password": password, - }).encode("utf-8")) - try: - with urlopen(req) as resp_data: - resp = json.load(resp_data) + url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint + if update_client: + url = url.with_query({"update_client": "true"}) + req_data = {"username": username, "password": password} + + async with sess.post(url, json=req_data) as resp: + if resp.status == 200: + data = await resp.json() action = "registered" if register else "logged in as" - print(f"{Fore.GREEN}Successfully {action} " - f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.") - print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}") - print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{resp['device_id']}{Fore.RESET}") - except HTTPError as e: - try: - err_data = json.load(e) - error = friendly_errors.get(err_data["errcode"], err_data["error"]) - except (json.JSONDecodeError, KeyError): - error = str(e) - action = "register" if register else "log in" - print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") + print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.") + print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}") + print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}") + elif resp.status in (201, 202): + data = await resp.json() + action = "created" if resp.status == 201 else "updated" + print(f"{Fore.GREEN}Successfully {action} client for " + f"{Fore.CYAN}{data['id']}{Fore.GREEN} / " + f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}") + else: + try: + err_data = await resp.json() + error = friendly_errors.get(err_data["errcode"], err_data["error"]) + except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): + error = await resp.text() + action = "register" if register else "log in" + print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") diff --git a/maubot/client.py b/maubot/client.py index 60fd335..7d0a6de 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -86,10 +86,8 @@ class Client: log=self.log, loop=self.loop, device_id=self.device_id, sync_store=SyncStoreProxy(self.db_instance), state_store=self.global_state_store) - if OlmMachine and self.device_id and self.crypto_db: - self.crypto_store = self._make_crypto_store() - self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) - self.client.crypto = self.crypto + if self.enable_crypto: + self._prepare_crypto() else: self.crypto_store = None self.crypto = None @@ -102,10 +100,15 @@ class Client: self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) - def _make_crypto_store(self) -> 'CryptoStore': - if self.crypto_db: - return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db) - raise ValueError("Crypto database not configured") + @property + def enable_crypto(self) -> bool: + return bool(OlmMachine and self.device_id and self.crypto_db) + + def _prepare_crypto(self) -> None: + self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", + db=self.crypto_db) + self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) + self.client.crypto = self.crypto def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: async def handler(data: Dict[str, Any]) -> None: @@ -121,6 +124,19 @@ class Client: except Exception: self.log.exception("Failed to start") + async def _start_crypto(self) -> None: + self.log.debug("Enabling end-to-end encryption support") + await self.crypto_store.open() + crypto_device_id = await self.crypto_store.get_device_id() + if crypto_device_id and crypto_device_id != self.device_id: + self.log.warning("Mismatching device ID in crypto store and main database, " + "resetting encryption") + await self.crypto_store.delete() + crypto_device_id = None + await self.crypto.load() + if not crypto_device_id: + await self.crypto_store.put_device_id(self.device_id) + async def _start(self, try_n: Optional[int] = 0) -> None: if not self.enabled: self.log.debug("Not starting disabled client") @@ -129,7 +145,7 @@ class Client: self.log.warning("Ignoring start() call to started client") return try: - user_id = await self.client.whoami() + whoami = await self.client.whoami() except MatrixInvalidToken as e: self.log.error(f"Invalid token: {e}. Disabling client") self.db_instance.enabled = False @@ -143,8 +159,13 @@ class Client: f"retrying in {(try_n + 1) * 10}s: {e}") _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) return - if user_id != self.id: - self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}") + if whoami.user_id != self.id: + self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}") + self.db_instance.enabled = False + return + elif whoami.device_id and self.device_id and whoami.device_id != self.device_id: + self.log.error(f"Device ID mismatch: expected {self.device_id}, " + f"but got {whoami.device_id}") self.db_instance.enabled = False return if not self.filter_id: @@ -167,15 +188,7 @@ class Client: if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) if self.crypto: - self.log.debug("Enabling end-to-end encryption support") - await self.crypto_store.open() - crypto_device_id = await self.crypto_store.get_device_id() - if crypto_device_id and crypto_device_id != self.device_id: - self.log.warning("Mismatching device ID in crypto store and main database. " - "Encryption may not work.") - await self.crypto.load() - if not crypto_device_id: - await self.crypto_store.put_device_id(self.device_id) + await self._start_crypto() self.start_sync() await self._update_remote_profile() self.started = True @@ -285,23 +298,31 @@ class Client: else: await self._update_remote_profile() - async def update_access_details(self, access_token: str, homeserver: str) -> None: + async def update_access_details(self, access_token: str, homeserver: str, + device_id: Optional[str] = None) -> None: if not access_token and not homeserver: return elif access_token == self.access_token and homeserver == self.homeserver: return + device_id = device_id or self.device_id new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, token=access_token or self.access_token, loop=self.loop, - client_session=self.http_client, device_id=self.device_id, + device_id=device_id, client_session=self.http_client, log=self.log, state_store=self.global_state_store) - mxid = await new_client.whoami() - if mxid != self.id: - raise ValueError(f"MXID mismatch: {mxid}") + whoami = await new_client.whoami() + if whoami.user_id != self.id: + raise ValueError(f"MXID mismatch: {whoami.user_id}") + elif whoami.device_id and device_id and whoami.device_id != device_id: + raise ValueError(f"Device ID mismatch: {whoami.device_id}") new_client.sync_store = SyncStoreProxy(self.db_instance) self.stop_sync() self.client = new_client self.db_instance.homeserver = homeserver self.db_instance.access_token = access_token + self.db_instance.device_id = device_id + if self.enable_crypto: + self._prepare_crypto() + await self._start_crypto() self.start_sync() async def _update_remote_profile(self) -> None: diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index 0585d63..27f2150 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -45,10 +45,11 @@ async def get_client(request: web.Request) -> web.Response: async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: homeserver = data.get("homeserver", None) access_token = data.get("access_token", None) + device_id = data.get("device_id", None) new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token, loop=Client.loop, client_session=Client.http_client) try: - mxid = await new_client.whoami() + whoami = await new_client.whoami() except MatrixInvalidToken: return resp.bad_client_access_token except MatrixRequestError: @@ -56,27 +57,31 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: except MatrixConnectionError: return resp.bad_client_connection_details if user_id is None: - existing_client = Client.get(mxid, None) + existing_client = Client.get(whoami.user_id, None) if existing_client is not None: return resp.user_exists - elif mxid != user_id: - return resp.mxid_mismatch(mxid) - db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token, + elif whoami.user_id != user_id: + return resp.mxid_mismatch(whoami.user_id) + elif whoami.device_id and device_id and whoami.device_id != device_id: + return resp.device_id_mismatch(whoami.device_id) + db_instance = DBClient(id=whoami.user_id, homeserver=homeserver, access_token=access_token, enabled=data.get("enabled", True), next_batch=SyncToken(""), filter_id=FilterID(""), sync=data.get("sync", True), autojoin=data.get("autojoin", True), online=data.get("online", True), displayname=data.get("displayname", ""), - avatar_url=data.get("avatar_url", "")) + avatar_url=data.get("avatar_url", ""), + device_id=device_id) client = Client(db_instance) client.db_instance.insert() await client.start() return resp.created(client.to_dict()) -async def _update_client(client: Client, data: dict) -> web.Response: +async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response: try: await client.update_access_details(data.get("access_token", None), - data.get("homeserver", None)) + data.get("homeserver", None), + data.get("device_id", None)) except MatrixInvalidToken: return resp.bad_client_access_token except MatrixRequestError: @@ -93,7 +98,16 @@ async def _update_client(client: Client, data: dict) -> web.Response: client.autojoin = data.get("autojoin", client.autojoin) client.online = data.get("online", client.online) client.sync = data.get("sync", client.sync) - return resp.updated(client.to_dict()) + return resp.updated(client.to_dict(), is_login=is_login) + + +async def _create_or_update_client(user_id: UserID, data: dict, is_login: bool = False + ) -> web.Response: + client = Client.get(user_id, None) + if not client: + return await _create_client(user_id, data) + else: + return await _update_client(client, data, is_login=is_login) @routes.post("/client/new") @@ -108,15 +122,11 @@ async def create_client(request: web.Request) -> web.Response: @routes.put("/client/{id}") async def update_client(request: web.Request) -> web.Response: user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) try: data = await request.json() except JSONDecodeError: return resp.body_not_json - if not client: - return await _create_client(user_id, data) - else: - return await _update_client(client, data) + return await _create_or_update_client(user_id, data) @routes.delete("/client/{id}") diff --git a/maubot/management/api/client_auth.py b/maubot/management/api/client_auth.py index a3a5878..abd5246 100644 --- a/maubot/management/api/client_auth.py +++ b/maubot/management/api/client_auth.py @@ -25,10 +25,11 @@ from aiohttp import web from mautrix.api import SynapseAdminPath, Method from mautrix.errors import MatrixRequestError from mautrix.client import ClientAPI -from mautrix.types import LoginType +from mautrix.types import LoginType, LoginResponse from .base import routes, get_config, get_loop from .responses import resp +from .client import _create_or_update_client, _create_client def known_homeservers() -> Dict[str, Dict[str, str]]: @@ -46,6 +47,7 @@ class AuthRequestInfo(NamedTuple): username: str password: str user_type: str + update_client: bool async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo], @@ -70,15 +72,16 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR secret = server.get("secret") api = ClientAPI(base_url=base_url, loop=get_loop()) user_type = body.get("user_type", "bot") - return AuthRequestInfo(api, secret, username, password, user_type), None + update_client = request.query.get("update_client", "").lower() in ("1", "true", "yes") + return AuthRequestInfo(api, secret, username, password, user_type, update_client), None -def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False, +def generate_mac(secret: str, nonce: str, username: str, password: str, admin: bool = False, user_type: str = None) -> str: mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1) mac.update(nonce.encode("utf-8")) mac.update(b"\x00") - mac.update(user.encode("utf-8")) + mac.update(username.encode("utf-8")) mac.update(b"\x00") mac.update(password.encode("utf-8")) mac.update(b"\x00") @@ -94,28 +97,34 @@ async def register(request: web.Request) -> web.Response: info, err = await read_client_auth_request(request) if err is not None: return err - client: ClientAPI - client, secret, username, password, user_type = info - if not secret: + if not info.secret: return resp.registration_secret_not_found path = SynapseAdminPath.v1.register - res = await client.api.request(Method.GET, path) + res = await info.client.api.request(Method.GET, path) content = { "nonce": res["nonce"], - "username": username, - "password": password, + "username": info.username, + "password": info.password, "admin": False, - "mac": generate_mac(secret, res["nonce"], username, password, user_type=user_type), - "user_type": user_type, + "user_type": info.user_type, } + content["mac"] = generate_mac(**content, secret=info.secret) try: - return web.json_response(await client.api.request(Method.POST, path, content=content)) + raw_res = await info.client.api.request(Method.POST, path, content=content) except MatrixRequestError as e: return web.json_response({ "errcode": e.errcode, "error": e.message, "http_status": e.http_status, }, status=HTTPStatus.INTERNAL_SERVER_ERROR) + login_res = LoginResponse.deserialize(raw_res) + if info.update_client: + return await _create_client(login_res.user_id, { + "homeserver": str(info.client.api.base_url), + "access_token": login_res.access_token, + "device_id": login_res.device_id, + }) + return web.json_response(login_res.serialize()) @routes.post("/client/auth/{server}/login") @@ -129,9 +138,15 @@ async def login(request: web.Request) -> web.Response: res = await client.login(identifier=info.username, login_type=LoginType.PASSWORD, password=info.password, device_id=f"maubot_{device_id}", initial_device_display_name="Maubot", store_access_token=False) - return web.json_response(res.serialize()) except MatrixRequestError as e: return web.json_response({ "errcode": e.errcode, "error": e.message, }, status=e.http_status) + if info.update_client: + return await _create_or_update_client(res.user_id, { + "homeserver": str(client.api.base_url), + "access_token": res.access_token, + "device_id": res.device_id, + }, is_login=True) + return web.json_response(res.serialize()) diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index a4e76d8..d6dee04 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -69,6 +69,13 @@ class _Response: "errcode": "mxid_mismatch", }, status=HTTPStatus.BAD_REQUEST) + def device_id_mismatch(self, found: str) -> web.Response: + return web.json_response({ + "error": "The Matrix device ID of the client and the device ID of the access token " + f"don't match. Access token is for device {found}", + "errcode": "mxid_mismatch", + }, status=HTTPStatus.BAD_REQUEST) + @property def pid_mismatch(self) -> web.Response: return web.json_response({ @@ -294,8 +301,9 @@ class _Response: def found(data: dict) -> web.Response: return web.json_response(data, status=HTTPStatus.OK) - def updated(self, data: dict) -> web.Response: - return self.found(data) + @staticmethod + def updated(data: dict, is_login: bool = False) -> web.Response: + return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK) def logged_in(self, token: str) -> web.Response: return self.found({ diff --git a/maubot/management/api/spec.yaml b/maubot/management/api/spec.yaml index c6f5181..a86358e 100644 --- a/maubot/management/api/spec.yaml +++ b/maubot/management/api/spec.yaml @@ -366,7 +366,7 @@ paths: schema: $ref: '#/components/schemas/MatrixClient' responses: - 200: + 202: description: Client updated content: application/json: @@ -454,6 +454,12 @@ paths: required: true schema: type: string + - name: update_client + in: query + description: Should maubot store the access details in a Client instead of returning them? + required: false + schema: + type: boolean post: operationId: client_auth_register summary: | @@ -475,18 +481,29 @@ paths: properties: access_token: type: string - example: token_here + example: syt_123_456_789 user_id: type: string example: '@putkiteippi:maunium.net' - home_server: - type: string - example: maunium.net device_id: type: string - example: device_id_here + example: maubot_F00BAR12 + 201: + description: Client created (when update_client is true) + content: + application/json: + schema: + $ref: '#/components/schemas/MatrixClient' 401: $ref: '#/components/responses/Unauthorized' + 409: + description: | + There is already a client with the user ID of that token. + This should usually not happen, because the user ID was just created. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' 500: $ref: '#/components/responses/MatrixServerError' '/client/auth/{server}/login': @@ -497,6 +514,12 @@ paths: required: true schema: type: string + - name: update_client + in: query + description: Should maubot store the access details in a Client instead of returning them? + required: false + schema: + type: boolean post: operationId: client_auth_login summary: Log in to the given Matrix server via the maubot server @@ -519,10 +542,22 @@ paths: example: '@putkiteippi:maunium.net' access_token: type: string - example: token_here + example: syt_123_456_789 device_id: type: string - example: device_id_here + example: maubot_F00BAR12 + 201: + description: Client created (when update_client is true) + content: + application/json: + schema: + $ref: '#/components/schemas/MatrixClient' + 202: + description: Client updated (when update_client is true) + content: + application/json: + schema: + $ref: '#/components/schemas/MatrixClient' 401: $ref: '#/components/responses/Unauthorized' 500: @@ -641,6 +676,9 @@ components: access_token: type: string description: The Matrix access token for this client. + device_id: + type: string + description: The Matrix device ID corresponding to the access token. enabled: type: boolean example: true diff --git a/maubot/standalone/__main__.py b/maubot/standalone/__main__.py index 97ce5cd..ffd1a79 100644 --- a/maubot/standalone/__main__.py +++ b/maubot/standalone/__main__.py @@ -144,13 +144,13 @@ async def main(): while True: try: - whoami_user_id = await client.whoami() + whoami = await client.whoami() except Exception: log.exception("Failed to connect to homeserver, retrying in 10 seconds...") await asyncio.sleep(10) continue - if whoami_user_id != user_id: - log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami_user_id}") + if whoami.user_id != user_id: + log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}") sys.exit(1) break