mirror of
https://github.com/maubot/maubot.git
synced 2024-10-01 01:06:10 -04:00
parent
861d81d2a6
commit
09a0efbf19
@ -25,7 +25,6 @@ import os.path
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.comments import CommentedMap
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util import background_task
|
||||
@ -36,6 +35,7 @@ from mautrix.util.logging import TraceLogger
|
||||
|
||||
from .client import Client
|
||||
from .db import DatabaseEngine, Instance as DBInstance
|
||||
from .lib.optionalalchemy import Engine, MetaData, create_engine
|
||||
from .lib.plugin_db import ProxyPostgresDatabase
|
||||
from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
|
||||
from .plugin_base import Plugin
|
||||
@ -128,7 +128,7 @@ class PluginInstance(DBInstance):
|
||||
}
|
||||
|
||||
def _introspect_sqlalchemy(self) -> dict:
|
||||
metadata = sql.MetaData()
|
||||
metadata = MetaData()
|
||||
metadata.reflect(self.inst_db)
|
||||
return {
|
||||
table.name: {
|
||||
@ -214,7 +214,7 @@ class PluginInstance(DBInstance):
|
||||
|
||||
async def get_db_tables(self) -> dict:
|
||||
if self.inst_db_tables is None:
|
||||
if isinstance(self.inst_db, sql.engine.Engine):
|
||||
if isinstance(self.inst_db, Engine):
|
||||
self.inst_db_tables = self._introspect_sqlalchemy()
|
||||
elif self.inst_db.scheme == Scheme.SQLITE:
|
||||
self.inst_db_tables = await self._introspect_sqlite()
|
||||
@ -294,7 +294,7 @@ class PluginInstance(DBInstance):
|
||||
"Instance database engine is marked as Postgres, but plugin uses legacy "
|
||||
"database interface, which doesn't support postgres."
|
||||
)
|
||||
self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}")
|
||||
self.inst_db = create_engine(f"sqlite:///{self._sqlite_db_path}")
|
||||
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
|
||||
if self.database_engine is None:
|
||||
if os.path.exists(self._sqlite_db_path) or not self.maubot.plugin_postgres_db:
|
||||
@ -329,7 +329,7 @@ class PluginInstance(DBInstance):
|
||||
async def stop_database(self) -> None:
|
||||
if isinstance(self.inst_db, Database):
|
||||
await self.inst_db.stop()
|
||||
elif isinstance(self.inst_db, sql.engine.Engine):
|
||||
elif isinstance(self.inst_db, Engine):
|
||||
self.inst_db.dispose()
|
||||
else:
|
||||
raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
|
||||
|
19
maubot/lib/optionalalchemy.py
Normal file
19
maubot/lib/optionalalchemy.py
Normal file
@ -0,0 +1,19 @@
|
||||
try:
|
||||
from sqlalchemy import MetaData, asc, create_engine, desc
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
except ImportError:
|
||||
|
||||
class FakeError(Exception):
|
||||
pass
|
||||
|
||||
class FakeType:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise Exception("SQLAlchemy is not installed")
|
||||
|
||||
def create_engine(*args, **kwargs):
|
||||
raise Exception("SQLAlchemy is not installed")
|
||||
|
||||
MetaData = Engine = FakeType
|
||||
IntegrityError = OperationalError = FakeError
|
||||
asc = desc = lambda a: a
|
@ -31,7 +31,7 @@ from ..config import Config
|
||||
from ..lib.zipimport import ZipImportError, zipimporter
|
||||
from ..plugin_base import Plugin
|
||||
from .abc import IDConflictError, PluginClass, PluginLoader
|
||||
from .meta import PluginMeta
|
||||
from .meta import DatabaseType, PluginMeta
|
||||
|
||||
current_version = Version(__version__)
|
||||
yaml = YAML()
|
||||
@ -155,9 +155,9 @@ class ZippedPluginLoader(PluginLoader):
|
||||
return file, meta
|
||||
|
||||
@classmethod
|
||||
def verify_meta(cls, source) -> tuple[str, Version]:
|
||||
def verify_meta(cls, source) -> tuple[str, Version, DatabaseType | None]:
|
||||
_, meta = cls._read_meta(source)
|
||||
return meta.id, meta.version
|
||||
return meta.id, meta.version, meta.database_type if meta.database else None
|
||||
|
||||
def _load_meta(self) -> None:
|
||||
file, meta = self._read_meta(self.path)
|
||||
|
@ -19,12 +19,12 @@ from datetime import datetime
|
||||
|
||||
from aiohttp import web
|
||||
from asyncpg import PostgresError
|
||||
from sqlalchemy import asc, desc, engine, exc
|
||||
import aiosqlite
|
||||
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ...instance import PluginInstance
|
||||
from ...lib.optionalalchemy import Engine, IntegrityError, OperationalError, asc, desc
|
||||
from .base import routes
|
||||
from .responses import resp
|
||||
|
||||
@ -66,7 +66,7 @@ async def get_table(request: web.Request) -> web.Response:
|
||||
except KeyError:
|
||||
order = []
|
||||
limit = int(request.query.get("limit", "100"))
|
||||
if isinstance(instance.inst_db, engine.Engine):
|
||||
if isinstance(instance.inst_db, Engine):
|
||||
return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit))
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ async def query(request: web.Request) -> web.Response:
|
||||
except KeyError:
|
||||
return resp.query_missing
|
||||
rows_as_dict = data.get("rows_as_dict", False)
|
||||
if isinstance(instance.inst_db, engine.Engine):
|
||||
if isinstance(instance.inst_db, Engine):
|
||||
return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict)
|
||||
elif isinstance(instance.inst_db, Database):
|
||||
try:
|
||||
@ -133,12 +133,12 @@ async def _execute_query_asyncpg(
|
||||
def _execute_query_sqlalchemy(
|
||||
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
|
||||
) -> web.Response:
|
||||
assert isinstance(instance.inst_db, engine.Engine)
|
||||
assert isinstance(instance.inst_db, Engine)
|
||||
try:
|
||||
res = instance.inst_db.execute(sql_query)
|
||||
except exc.IntegrityError as e:
|
||||
except IntegrityError as e:
|
||||
return resp.sql_integrity_error(e, sql_query)
|
||||
except exc.OperationalError as e:
|
||||
except OperationalError as e:
|
||||
return resp.sql_operational_error(e, sql_query)
|
||||
data = {
|
||||
"ok": True,
|
||||
|
@ -23,10 +23,17 @@ import traceback
|
||||
from aiohttp import web
|
||||
from packaging.version import Version
|
||||
|
||||
from ...loader import MaubotZipImportError, PluginLoader, ZippedPluginLoader
|
||||
from ...loader import DatabaseType, MaubotZipImportError, PluginLoader, ZippedPluginLoader
|
||||
from .base import get_config, routes
|
||||
from .responses import resp
|
||||
|
||||
try:
|
||||
import sqlalchemy
|
||||
|
||||
has_alchemy = True
|
||||
except ImportError:
|
||||
has_alchemy = False
|
||||
|
||||
log = logging.getLogger("maubot.server.upload")
|
||||
|
||||
|
||||
@ -36,9 +43,11 @@ async def put_plugin(request: web.Request) -> web.Response:
|
||||
content = await request.read()
|
||||
file = BytesIO(content)
|
||||
try:
|
||||
pid, version = ZippedPluginLoader.verify_meta(file)
|
||||
pid, version, db_type = ZippedPluginLoader.verify_meta(file)
|
||||
except MaubotZipImportError as e:
|
||||
return resp.plugin_import_error(str(e), traceback.format_exc())
|
||||
if db_type == DatabaseType.SQLALCHEMY and not has_alchemy:
|
||||
return resp.sqlalchemy_not_installed
|
||||
if pid != plugin_id:
|
||||
return resp.pid_mismatch
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
@ -55,9 +64,11 @@ async def upload_plugin(request: web.Request) -> web.Response:
|
||||
content = await request.read()
|
||||
file = BytesIO(content)
|
||||
try:
|
||||
pid, version = ZippedPluginLoader.verify_meta(file)
|
||||
pid, version, db_type = ZippedPluginLoader.verify_meta(file)
|
||||
except MaubotZipImportError as e:
|
||||
return resp.plugin_import_error(str(e), traceback.format_exc())
|
||||
if db_type == DatabaseType.SQLALCHEMY and not has_alchemy:
|
||||
return resp.sqlalchemy_not_installed
|
||||
plugin = PluginLoader.id_cache.get(pid, None)
|
||||
if not plugin:
|
||||
return await upload_new_plugin(content, pid, version)
|
||||
|
@ -15,13 +15,16 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from http import HTTPStatus
|
||||
|
||||
from aiohttp import web
|
||||
from asyncpg import PostgresError
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
import aiosqlite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
||||
|
||||
class _Response:
|
||||
@property
|
||||
@ -324,6 +327,16 @@ class _Response:
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def sqlalchemy_not_installed(self) -> web.Response:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": "This plugin requires a legacy database, but SQLAlchemy is not installed",
|
||||
"errcode": "unsupported_plugin_database",
|
||||
},
|
||||
status=HTTPStatus.NOT_IMPLEMENTED,
|
||||
)
|
||||
|
||||
@property
|
||||
def table_not_found(self) -> web.Response:
|
||||
return web.json_response(
|
||||
|
@ -20,7 +20,6 @@ from abc import ABC
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.util.async_db import Database, UpgradeTable
|
||||
@ -30,6 +29,8 @@ from mautrix.util.logging import TraceLogger
|
||||
from .scheduler import BasicScheduler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine.base import Engine
|
||||
|
||||
from .client import MaubotMatrixClient
|
||||
from .loader import BasePluginLoader
|
||||
from .plugin_server import PluginWebApp
|
||||
@ -56,7 +57,7 @@ class Plugin(ABC):
|
||||
instance_id: str,
|
||||
log: TraceLogger,
|
||||
config: BaseProxyConfig | None,
|
||||
database: Engine | None,
|
||||
database: Engine | Database | None,
|
||||
webapp: PluginWebApp | None,
|
||||
webapp_url: str | None,
|
||||
loader: BasePluginLoader,
|
||||
|
@ -9,3 +9,6 @@ unpaddedbase64>=1,<3
|
||||
#/testing
|
||||
pytest
|
||||
pytest-asyncio
|
||||
|
||||
#/legacydb
|
||||
SQLAlchemy>1,<1.4
|
||||
|
@ -1,7 +1,6 @@
|
||||
mautrix>=0.20.6,<0.21
|
||||
aiohttp>=3,<4
|
||||
yarl>=1,<2
|
||||
SQLAlchemy>=1,<1.4
|
||||
asyncpg>=0.20,<0.30
|
||||
aiosqlite>=0.16,<0.21
|
||||
commonmark>=0.9,<1
|
||||
|
Loading…
Reference in New Issue
Block a user