diff --git a/maubot/instance.py b/maubot/instance.py index 8d9bd6e..b9b1c23 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -28,7 +28,7 @@ from ruamel.yaml.comments import CommentedMap import sqlalchemy as sql from mautrix.types import UserID -from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable +from mautrix.util.async_db import Database, Scheme, UpgradeTable from mautrix.util.async_getter_lock import async_getter_lock from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.util.logging import TraceLogger @@ -65,7 +65,7 @@ class PluginInstance(DBInstance): base_cfg: RecursiveDict[CommentedMap] | None base_cfg_str: str | None inst_db: sql.engine.Engine | Database | None - inst_db_tables: dict[str, sql.Table] | None + inst_db_tables: dict | None inst_webapp: PluginWebApp | None inst_webapp_url: str | None started: bool @@ -113,11 +113,99 @@ class PluginInstance(DBInstance): ), } - def get_db_tables(self) -> dict[str, sql.Table]: - if not self.inst_db_tables: - metadata = sql.MetaData() - metadata.reflect(self.inst_db) - self.inst_db_tables = metadata.tables + def _introspect_sqlalchemy(self) -> dict: + metadata = sql.MetaData() + metadata.reflect(self.inst_db) + return { + table.name: { + "columns": { + column.name: { + "type": str(column.type), + "unique": column.unique or False, + "default": column.default, + "nullable": column.nullable, + "primary": column.primary_key, + } + for column in table.columns + }, + } + for table in metadata.tables.values() + } + + async def _introspect_sqlite(self) -> dict: + q = """ + SELECT + m.name AS table_name, + p.cid AS col_id, + p.name AS column_name, + p.type AS data_type, + p.pk AS is_primary, + p.dflt_value AS column_default, + p.[notnull] AS is_nullable + FROM sqlite_master m + LEFT JOIN pragma_table_info((m.name)) p + WHERE m.type = 'table' + ORDER BY table_name, col_id + """ + data = await self.inst_db.fetch(q) + tables = defaultdict(lambda: {"columns": {}}) + for column in data: + table_name = column["table_name"] + col_name = column["column_name"] + tables[table_name]["columns"][col_name] = { + "type": column["data_type"], + "nullable": bool(column["is_nullable"]), + "default": column["column_default"], + "primary": bool(column["is_primary"]), + # TODO uniqueness? + } + return tables + + async def _introspect_postgres(self) -> dict: + assert isinstance(self.inst_db, ProxyPostgresDatabase) + q = """ + SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default, + tc.constraint_type + FROM information_schema.columns col + LEFT JOIN information_schema.constraint_column_usage ccu + ON ccu.column_name=col.column_name + LEFT JOIN information_schema.table_constraints tc + ON col.table_name=tc.table_name + AND col.table_schema=tc.table_schema + AND ccu.constraint_name=tc.constraint_name + AND ccu.constraint_schema=tc.constraint_schema + AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE') + WHERE col.table_schema=$1 + """ + data = await self.inst_db.fetch(q, self.inst_db.schema_name) + tables = defaultdict(lambda: {"columns": {}}) + for column in data: + table_name = column["table_name"] + col_name = column["column_name"] + tables[table_name]["columns"].setdefault( + col_name, + { + "type": column["data_type"], + "nullable": column["is_nullable"], + "default": column["column_default"], + "primary": False, + "unique": False, + }, + ) + if column["constraint_type"] == "PRIMARY KEY": + tables[table_name]["columns"][col_name]["primary"] = True + elif column["constraint_type"] == "UNIQUE": + tables[table_name]["columns"][col_name]["unique"] = True + return tables + + async def get_db_tables(self) -> dict: + if self.inst_db_tables is None: + if isinstance(self.inst_db, sql.engine.Engine): + self.inst_db_tables = self._introspect_sqlalchemy() + elif self.inst_db.scheme == Scheme.SQLITE: + self.inst_db_tables = await self._introspect_sqlite() + else: + self.inst_db_tables = await self._introspect_postgres() return self.inst_db_tables async def load(self) -> bool: diff --git a/maubot/lib/plugin_db.py b/maubot/lib/plugin_db.py index a99d461..af2cc55 100644 --- a/maubot/lib/plugin_db.py +++ b/maubot/lib/plugin_db.py @@ -28,7 +28,8 @@ remove_double_quotes = str.maketrans({'"': "_"}) class ProxyPostgresDatabase(Database): scheme = Scheme.POSTGRES _underlying_pool: PostgresDatabase - _schema: str + schema_name: str + _quoted_schema: str _default_search_path: str _conn_sema: asyncio.Semaphore @@ -44,7 +45,8 @@ class ProxyPostgresDatabase(Database): self._underlying_pool = pool # Simple accidental SQL injection prevention. # Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway. - self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"' + self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}" + self._quoted_schema = f'"{self.schema_name}"' self._default_search_path = '"$user", public' self._conn_sema = asyncio.BoundedSemaphore(max_conns) @@ -52,7 +54,7 @@ class ProxyPostgresDatabase(Database): async with self._underlying_pool.acquire() as conn: self._default_search_path = await conn.fetchval("SHOW search_path") self.log.debug(f"Found default search path: {self._default_search_path}") - await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}") + await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}") await super().start() async def stop(self) -> None: @@ -67,9 +69,11 @@ class ProxyPostgresDatabase(Database): break async def delete(self) -> None: - self.log.debug(f"Deleting schema {self._schema} and all data in it") + self.log.debug(f"Deleting schema {self._quoted_schema} and all data in it") try: - await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE") + await self._underlying_pool.execute( + f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE" + ) except Exception: self.log.warning("Failed to delete schema", exc_info=True) @@ -77,7 +81,7 @@ class ProxyPostgresDatabase(Database): async def acquire(self) -> LoggingConnection: conn: LoggingConnection async with self._conn_sema, self._underlying_pool.acquire() as conn: - await conn.execute(f"SET search_path = {self._default_search_path}") + await conn.execute(f"SET search_path = {self._quoted_schema}") try: yield conn finally: diff --git a/maubot/management/api/instance_database.py b/maubot/management/api/instance_database.py index 25869ce..d271031 100644 --- a/maubot/management/api/instance_database.py +++ b/maubot/management/api/instance_database.py @@ -18,9 +18,11 @@ from __future__ import annotations from datetime import datetime from aiohttp import web -from sqlalchemy import Column, Table, asc, desc, exc +from sqlalchemy import asc, desc, engine, exc from sqlalchemy.engine.result import ResultProxy, RowProxy -from sqlalchemy.orm import Query +import aiosqlite + +from mautrix.util.async_db import Database from ...instance import PluginInstance from .base import routes @@ -35,32 +37,7 @@ async def get_database(request: web.Request) -> web.Response: return resp.instance_not_found elif not instance.inst_db: return resp.plugin_has_no_database - table: Table - column: Column - return web.json_response( - { - table.name: { - "columns": { - column.name: { - "type": str(column.type), - "unique": column.unique or False, - "default": column.default, - "nullable": column.nullable, - "primary": column.primary_key, - "autoincrement": column.autoincrement, - } - for column in table.columns - }, - } - for table in instance.get_db_tables().values() - } - ) - - -def check_type(val): - if isinstance(val, datetime): - return val.isoformat() - return val + return web.json_response(await instance.get_db_tables()) @routes.get("/instance/{id}/database/{table}") @@ -71,7 +48,7 @@ async def get_table(request: web.Request) -> web.Response: return resp.instance_not_found elif not instance.inst_db: return resp.plugin_has_no_database - tables = instance.get_db_tables() + tables = await instance.get_db_tables() try: table = tables[request.match_info.get("table", "")] except KeyError: @@ -87,7 +64,8 @@ async def get_table(request: web.Request) -> web.Response: except KeyError: order = [] limit = int(request.query.get("limit", "100")) - return execute_query(instance, table.select().order_by(*order).limit(limit)) + if isinstance(instance.inst_db, engine.Engine): + return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit)) @routes.post("/instance/{id}/database/query") @@ -103,12 +81,54 @@ async def query(request: web.Request) -> web.Response: sql_query = data["query"] except KeyError: return resp.query_missing - return execute_query(instance, sql_query, rows_as_dict=data.get("rows_as_dict", False)) + rows_as_dict = data.get("rows_as_dict", False) + if isinstance(instance.inst_db, engine.Engine): + return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict) + elif isinstance(instance.inst_db, Database): + return await _execute_query_asyncpg(instance, sql_query, rows_as_dict) + else: + return resp.unsupported_plugin_database -def execute_query( - instance: PluginInstance, sql_query: str | Query, rows_as_dict: bool = False +def check_type(val): + if isinstance(val, datetime): + return val.isoformat() + return val + + +async def _execute_query_asyncpg( + instance: PluginInstance, sql_query: str, rows_as_dict: bool = False ) -> web.Response: + data = {"ok": True, "query": sql_query} + if sql_query.upper().startswith("SELECT"): + res = await instance.inst_db.fetch(sql_query) + data["rows"] = [ + ( + {key: check_type(value) for key, value in row.items()} + if rows_as_dict + else [check_type(value) for value in row] + ) + for row in res + ] + if len(res) > 0: + # TODO can we find column names when there are no rows? + data["columns"] = list(res[0].keys()) + else: + res = await instance.inst_db.execute(sql_query) + if isinstance(res, str): + data["status_msg"] = res + elif isinstance(res, aiosqlite.Cursor): + data["rowcount"] = res.rowcount + # data["inserted_primary_key"] = res.lastrowid + else: + data["status_msg"] = "unknown status" + return web.json_response(data) + + +def _execute_query_sqlalchemy( + instance: PluginInstance, sql_query: str, rows_as_dict: bool = False +) -> web.Response: + assert isinstance(instance.inst_db, engine.Engine) try: res: ResultProxy = instance.inst_db.execute(sql_query) except exc.IntegrityError as e: diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index 8e07abb..b645def 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -299,6 +299,15 @@ class _Response: } ) + @property + def unsupported_plugin_database(self) -> web.Response: + return web.json_response( + { + "error": "The database type is not supported by this API", + "errcode": "unsupported_plugin_database", + } + ) + @property def table_not_found(self) -> web.Response: return web.json_response( diff --git a/maubot/management/frontend/src/pages/dashboard/InstanceDatabase.js b/maubot/management/frontend/src/pages/dashboard/InstanceDatabase.js index 36435e9..2f3f525 100644 --- a/maubot/management/frontend/src/pages/dashboard/InstanceDatabase.js +++ b/maubot/management/frontend/src/pages/dashboard/InstanceDatabase.js @@ -44,6 +44,7 @@ class InstanceDatabase extends Component { error: null, prevQuery: null, + statusMsg: null, rowCount: null, insertedPrimaryKey: null, } @@ -111,6 +112,7 @@ class InstanceDatabase extends Component { prevQuery: null, rowCount: null, insertedPrimaryKey: null, + statusMsg: null, error: null, }) } @@ -127,7 +129,8 @@ class InstanceDatabase extends Component { this.setState({ prevQuery: res.query, rowCount: res.rowcount, - insertedPrimaryKey: res.insertedPrimaryKey, + insertedPrimaryKey: res.inserted_primary_key, + statusMsg: res.status_msg, }) this.buildSQLQuery(this.state.selectedTable, false) } @@ -298,8 +301,10 @@ class InstanceDatabase extends Component { } {this.state.prevQuery &&

- Executed {this.state.prevQuery} - - affected {this.state.rowCount} rows. + Executed {this.state.prevQuery} - { + this.state.statusMsg + || <>affected {this.state.rowCount} rows. + }

{this.state.insertedPrimaryKey &&

Inserted primary key: {this.state.insertedPrimaryKey}