Switch to asyncpg for database

This commit is contained in:
Tulir Asokan 2022-03-26 14:32:18 +02:00
parent 428b471fec
commit 18ef939a04
7 changed files with 470 additions and 310 deletions

26
.github/workflows/python-lint.yml vendored Normal file
View File

@ -0,0 +1,26 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: "3.10"
- uses: isort/isort-action@master
with:
sortPaths: "./rss"
- uses: psf/black@stable
with:
src: "./rss"
version: "22.1.0"
- name: pre-commit
run: |
pip install pre-commit
pre-commit run -av trailing-whitespace
pre-commit run -av end-of-file-fixer
pre-commit run -av check-yaml
pre-commit run -av check-added-large-files

23
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,23 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
# TODO convert to use the upstream psf/black when
# https://github.com/psf/black/issues/2493 gets fixed
- repo: local
hooks:
- id: black
name: black
entry: black --check
language: system
files: ^rss/.*\.py$
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort
files: ^rss/.*$

View File

@ -1,4 +1,4 @@
maubot: 0.1.0
maubot: 0.3.0
id: xyz.maubot.rss
version: 0.2.6
license: AGPL-3.0-or-later
@ -10,3 +10,4 @@ extra_files:
dependencies:
- feedparser>=5.1
database: true
database_type: asyncpg

12
pyproject.toml Normal file
View File

@ -0,0 +1,12 @@
[tool.isort]
profile = "black"
force_to_top = "typing"
from_first = true
combine_as_imports = true
known_first_party = ["mautrix", "maubot"]
line_length = 99
[tool.black]
line-length = 99
target-version = ["py38"]
required-version = "22.1.0"

View File

@ -1,5 +1,5 @@
# rss - A maubot plugin to subscribe to RSS/Atom feeds.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,23 +13,34 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, List, Any, Dict, Tuple, Awaitable, Iterable, Optional
from __future__ import annotations
from typing import Any, Iterable
from datetime import datetime
from time import mktime, time
from string import Template
from time import mktime, time
import asyncio
import hashlib
import aiohttp
import hashlib
import attr
import feedparser
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from mautrix.types import (StateEvent, EventType, MessageType, RoomID, EventID,
PowerLevelStateEventContent)
from maubot import Plugin, MessageEvent
from maubot import MessageEvent, Plugin
from maubot.handlers import command, event
from mautrix.types import (
EventID,
EventType,
MessageType,
PowerLevelStateEventContent,
RoomID,
StateEvent,
)
from mautrix.util.async_db import UpgradeTable
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from .db import Database, Feed, Entry, Subscription
from .db import DBManager, Entry, Feed, Subscription
from .migrations import upgrade_table
rss_change_level = EventType.find("xyz.maubot.rss", t_class=EventType.Class.STATE)
@ -47,7 +58,7 @@ class BoolArgument(command.Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False) -> None:
super().__init__(name, label, required=required, pass_raw=False)
def match(self, val: str, **kwargs) -> Tuple[str, Any]:
def match(self, val: str, **kwargs) -> tuple[str, Any]:
part = val.split(" ")[0].lower()
if part in ("f", "false", "n", "no", "0"):
res = False
@ -55,23 +66,27 @@ class BoolArgument(command.Argument):
res = True
else:
raise ValueError("invalid boolean")
return val[len(part):], res
return val[len(part) :], res
class RSSBot(Plugin):
db: Database
dbm: DBManager
poll_task: asyncio.Future
http: aiohttp.ClientSession
power_level_cache: Dict[RoomID, Tuple[int, PowerLevelStateEventContent]]
power_level_cache: dict[RoomID, tuple[int, PowerLevelStateEventContent]]
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
def get_config_class(cls) -> type[BaseProxyConfig]:
return Config
@classmethod
def get_db_upgrade_table(cls) -> UpgradeTable:
return upgrade_table
async def start(self) -> None:
await super().start()
self.config.load_and_update()
self.db = Database(self.database)
self.dbm = DBManager(self.database)
self.http = self.client.api.session
self.power_level_cache = {}
self.poll_task = asyncio.ensure_future(self.poll_feeds(), loop=self.loop)
@ -89,21 +104,26 @@ class RSSBot(Plugin):
self.log.exception("Fatal error while polling feeds")
async def _send(self, feed: Feed, entry: Entry, sub: Subscription) -> EventID:
message = sub.notification_template.safe_substitute({
"feed_url": feed.url,
"feed_title": feed.title,
"feed_subtitle": feed.subtitle,
"feed_link": feed.link,
**entry._asdict(),
})
message = sub.notification_template.safe_substitute(
{
"feed_url": feed.url,
"feed_title": feed.title,
"feed_subtitle": feed.subtitle,
"feed_link": feed.link,
**attr.asdict(entry),
}
)
msgtype = MessageType.NOTICE if sub.send_notice else MessageType.TEXT
try:
return await self.client.send_markdown(sub.room_id, message, msgtype=msgtype,
allow_html=True)
return await self.client.send_markdown(
sub.room_id, message, msgtype=msgtype, allow_html=True
)
except Exception as e:
self.log.warning(f"Failed to send {entry.id} of {feed.id} to {sub.room_id}: {e}")
async def _broadcast(self, feed: Feed, entry: Entry, subscriptions: List[Subscription]) -> None:
async def _broadcast(
self, feed: Feed, entry: Entry, subscriptions: list[Subscription]
) -> None:
self.log.debug(f"Broadcasting {entry.id} of {feed.id}")
spam_sleep = self.config["spam_sleep"]
tasks = [self._send(feed, entry, sub) for sub in subscriptions]
@ -115,7 +135,7 @@ class RSSBot(Plugin):
await asyncio.gather(*tasks)
async def _poll_once(self) -> None:
subs = self.db.get_feeds()
subs = await self.dbm.get_feeds()
if not subs:
return
now = int(time())
@ -125,30 +145,33 @@ class RSSBot(Plugin):
self.log.info(f"Polling {len(tasks)} feeds")
for res in asyncio.as_completed(tasks):
feed, entries = await res
self.log.trace(f"Fetching {feed.id} (backoff: {feed.error_count} / {feed.next_retry}) "
f"success: {bool(entries)}")
self.log.trace(
f"Fetching {feed.id} (backoff: {feed.error_count} / {feed.next_retry}) "
f"success: {bool(entries)}"
)
if not entries:
error_count = feed.error_count + 1
next_retry_delay = self.config["update_interval"] * 60 * error_count
next_retry_delay = min(next_retry_delay, self.config["max_backoff"] * 60)
next_retry = int(time() + next_retry_delay)
self.log.debug(f"Setting backoff of {feed.id} to {error_count} / {next_retry}")
self.db.set_backoff(feed, error_count, next_retry)
await self.dbm.set_backoff(feed, error_count, next_retry)
continue
elif feed.error_count > 0:
self.log.debug(f"Resetting backoff of {feed.id}")
self.db.set_backoff(feed, error_count=0, next_retry=0)
await self.dbm.set_backoff(feed, error_count=0, next_retry=0)
try:
new_entries = {entry.id: entry for entry in entries}
except Exception:
self.log.exception(f"Weird error in items of {feed.url}")
continue
for old_entry in self.db.get_entries(feed.id):
for old_entry in await self.dbm.get_entries(feed.id):
new_entries.pop(old_entry.id, None)
self.log.trace(f"Feed {feed.id} had {len(new_entries)} new entries")
self.db.add_entries(new_entries.values())
# TODO sort properly?
for entry in reversed(new_entries.values()):
new_entry_list: list[Entry] = list(new_entries.values())
new_entry_list.sort(key=lambda entry: (entry.date, entry.id))
await self.dbm.add_entries(new_entry_list)
for entry in new_entry_list:
await self._broadcast(feed, entry, feed.subscriptions)
self.log.info(f"Finished polling {len(tasks)} feeds")
@ -163,21 +186,24 @@ class RSSBot(Plugin):
self.log.exception("Error while polling feeds")
await asyncio.sleep(self.config["update_interval"] * 60, loop=self.loop)
async def try_parse_feed(self, feed: Optional[Feed] = None) -> Tuple[Feed, Iterable[Entry]]:
async def try_parse_feed(self, feed: Feed | None = None) -> tuple[Feed, list[Entry]]:
try:
self.log.trace(f"Trying to fetch {feed.id} / {feed.url} "
f"(backoff: {feed.error_count} / {feed.next_retry})")
self.log.trace(
f"Trying to fetch {feed.id} / {feed.url} "
f"(backoff: {feed.error_count} / {feed.next_retry})"
)
return await self.parse_feed(feed=feed)
except Exception as e:
self.log.warning(f"Failed to parse feed {feed.id} / {feed.url}: {e}")
return feed, []
async def parse_feed(self, *, feed: Optional[Feed] = None, url: Optional[str] = None
) -> Tuple[Feed, Iterable[Entry]]:
async def parse_feed(
self, *, feed: Feed | None = None, url: str | None = None
) -> tuple[Feed, list[Entry]]:
if feed is None:
if url is None:
raise ValueError("Either feed or url must be set")
feed = Feed(-1, url, "", "", "", 0, 0, [])
feed = Feed(id=-1, url=url, title="", subtitle="", link="")
elif url is not None:
raise ValueError("Only one of feed or url must be set")
resp = await self.http.get(feed.url)
@ -188,38 +214,40 @@ class RSSBot(Plugin):
return await self._parse_rss(feed, resp)
@classmethod
async def _parse_json(cls, feed: Feed, resp: aiohttp.ClientResponse
) -> Tuple[Feed, Iterable[Entry]]:
async def _parse_json(
cls, feed: Feed, resp: aiohttp.ClientResponse
) -> tuple[Feed, list[Entry]]:
content = await resp.json()
if content["version"] not in ("https://jsonfeed.org/version/1",
"https://jsonfeed.org/version/1.1"):
if content["version"] not in (
"https://jsonfeed.org/version/1",
"https://jsonfeed.org/version/1.1",
):
raise ValueError("Unsupported JSON feed version")
if not isinstance(content["items"], list):
raise ValueError("Feed is not a valid JSON feed (items is not a list)")
feed = Feed(id=feed.id, title=content["title"], subtitle=content.get("subtitle", ""),
url=feed.url, link=content.get("home_page_url", ""),
next_retry=feed.next_retry, error_count=feed.error_count,
subscriptions=feed.subscriptions)
return feed, (cls._parse_json_entry(feed.id, entry) for entry in content["items"])
feed.title = content["title"]
feed.subtitle = content.get("subtitle", "")
feed.link = content.get("home_page_url", "")
return feed, [cls._parse_json_entry(feed.id, entry) for entry in content["items"]]
@classmethod
def _parse_json_entry(cls, feed_id: int, entry: Dict[str, Any]) -> Entry:
def _parse_json_entry(cls, feed_id: int, entry: dict[str, Any]) -> Entry:
try:
date = datetime.fromisoformat(entry["date_published"])
except (ValueError, KeyError):
date = datetime.now()
title = entry.get("title", "")
summary = (entry.get("summary")
or entry.get("content_html")
or entry.get("content_text")
or "").strip()
summary = (
entry.get("summary") or entry.get("content_html") or entry.get("content_text") or ""
).strip()
id = str(entry["id"])
link = entry.get("url") or id
return Entry(feed_id=feed_id, id=id, date=date, title=title, summary=summary, link=link)
@classmethod
async def _parse_rss(cls, feed: Feed, resp: aiohttp.ClientResponse
) -> Tuple[Feed, Iterable[Entry]]:
async def _parse_rss(
cls, feed: Feed, resp: aiohttp.ClientResponse
) -> tuple[Feed, list[Entry]]:
try:
content = await resp.text()
except UnicodeDecodeError:
@ -233,21 +261,27 @@ class RSSBot(Plugin):
if not isinstance(parsed_data.bozo_exception, feedparser.ThingsNobodyCaresAboutButMe):
raise parsed_data.bozo_exception
feed_data = parsed_data.get("feed", {})
feed = Feed(id=feed.id, url=feed.url, title=feed_data.get("title", feed.url),
subtitle=feed_data.get("description", ""), link=feed_data.get("link", ""),
error_count=feed.error_count, next_retry=feed.next_retry,
subscriptions=feed.subscriptions)
return feed, (cls._parse_rss_entry(feed.id, entry) for entry in parsed_data.entries)
feed.title = feed_data.get("title", feed.url)
feed.subtitle = feed_data.get("description", "")
feed.link = feed_data.get("link", "")
return feed, [cls._parse_rss_entry(feed.id, entry) for entry in parsed_data.entries]
@classmethod
def _parse_rss_entry(cls, feed_id: int, entry: Any) -> Entry:
return Entry(
feed_id=feed_id,
id=(getattr(entry, "id", None) or
hashlib.sha1(" ".join([getattr(entry, "title", ""),
getattr(entry, "description", ""),
getattr(entry, "link", "")]).encode("utf-8")
).hexdigest()),
id=(
getattr(entry, "id", None)
or hashlib.sha1(
" ".join(
[
getattr(entry, "title", ""),
getattr(entry, "description", ""),
getattr(entry, "link", ""),
]
).encode("utf-8")
).hexdigest()
),
date=cls._parse_rss_date(entry),
title=getattr(entry, "title", ""),
summary=getattr(entry, "description", "").strip(),
@ -286,109 +320,138 @@ class RSSBot(Plugin):
if not isinstance(state_level, int):
state_level = 50
if user_level < state_level:
await evt.reply("You don't have the permission to "
"manage the subscriptions of this room.")
await evt.reply(
"You don't have the permission to manage the subscriptions of this room."
)
return False
return True
@command.new(name=lambda self: self.config["command_prefix"],
help="Manage this RSS bot", require_subcommand=True)
@command.new(
name=lambda self: self.config["command_prefix"],
help="Manage this RSS bot",
require_subcommand=True,
)
async def rss(self) -> None:
pass
@rss.subcommand("subscribe", aliases=("s", "sub"),
help="Subscribe this room to a feed.")
@rss.subcommand("subscribe", aliases=("s", "sub"), help="Subscribe this room to a feed.")
@command.argument("url", "feed URL", pass_raw=True)
async def subscribe(self, evt: MessageEvent, url: str) -> None:
if not await self.can_manage(evt):
return
feed = self.db.get_feed_by_url(url)
feed = await self.dbm.get_feed_by_url(url)
if not feed:
try:
info, entries = await self.parse_feed(url=url)
except Exception as e:
await evt.reply(f"Failed to load feed: {e}")
return
feed = self.db.create_feed(info)
self.db.add_entries(entries, override_feed_id=feed.id)
feed = await self.dbm.create_feed(info)
await self.dbm.add_entries(entries, override_feed_id=feed.id)
elif feed.error_count > 0:
self.db.set_backoff(feed, error_count=feed.error_count, next_retry=0)
await self.dbm.set_backoff(feed, error_count=feed.error_count, next_retry=0)
feed_info = f"feed ID {feed.id}: [{feed.title}]({feed.url})"
sub, _ = self.db.get_subscription(feed.id, evt.room_id)
sub, _ = await self.dbm.get_subscription(feed.id, evt.room_id)
if sub is not None:
subscriber = ("You" if sub.user_id == evt.sender
else f"[{sub.user_id}](https://matrix.to/#/{sub.user_id})")
subscriber = (
"You"
if sub.user_id == evt.sender
else f"[{sub.user_id}](https://matrix.to/#/{sub.user_id})"
)
await evt.reply(f"{subscriber} had already subscribed this room to {feed_info}")
else:
self.db.subscribe(feed.id, evt.room_id, evt.sender)
await self.dbm.subscribe(feed.id, evt.room_id, evt.sender)
await evt.reply(f"Subscribed to {feed_info}")
@rss.subcommand("unsubscribe", aliases=("u", "unsub"),
help="Unsubscribe this room from a feed.")
@rss.subcommand(
"unsubscribe", aliases=("u", "unsub"), help="Unsubscribe this room from a feed."
)
@command.argument("feed_id", "feed ID", parser=int)
async def unsubscribe(self, evt: MessageEvent, feed_id: int) -> None:
if not await self.can_manage(evt):
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
sub, feed = await self.dbm.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.unsubscribe(feed.id, evt.room_id)
await self.dbm.unsubscribe(feed.id, evt.room_id)
await evt.reply(f"Unsubscribed from feed ID {feed.id}: [{feed.title}]({feed.url})")
@rss.subcommand("template", aliases=("t", "tpl"),
help="Change the notification template for a subscription in this room")
@rss.subcommand(
"template",
aliases=("t", "tpl"),
help="Change the notification template for a subscription in this room",
)
@command.argument("feed_id", "feed ID", parser=int)
@command.argument("template", "new template", pass_raw=True)
async def command_template(self, evt: MessageEvent, feed_id: int, template: str) -> None:
if not await self.can_manage(evt):
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
sub, feed = await self.dbm.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.update_template(feed.id, evt.room_id, template)
sub = Subscription(feed_id=feed.id, room_id=sub.room_id, user_id=sub.user_id,
notification_template=Template(template), send_notice=sub.send_notice)
sample_entry = Entry(feed.id, "SAMPLE", datetime.now(), "Sample entry",
"This is a sample entry to demonstrate your new template",
"http://example.com")
await self.dbm.update_template(feed.id, evt.room_id, template)
sub = Subscription(
feed_id=feed.id,
room_id=sub.room_id,
user_id=sub.user_id,
notification_template=Template(template),
send_notice=sub.send_notice,
)
sample_entry = Entry(
feed_id=feed.id,
id="SAMPLE",
date=datetime.now(),
title="Sample entry",
summary="This is a sample entry to demonstrate your new template",
link="http://example.com",
)
await evt.reply(f"Template for feed ID {feed.id} updated. Sample notification:")
await self._send(feed, sample_entry, sub)
@rss.subcommand("notice", aliases=("n",),
help="Set whether or not the bot should send updates as m.notice")
@rss.subcommand(
"notice", aliases=("n",), help="Set whether or not the bot should send updates as m.notice"
)
@command.argument("feed_id", "feed ID", parser=int)
@BoolArgument("setting", "true/false")
async def command_notice(self, evt: MessageEvent, feed_id: int, setting: bool) -> None:
if not await self.can_manage(evt):
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
sub, feed = await self.dbm.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.set_send_notice(feed.id, evt.room_id, setting)
await self.dbm.set_send_notice(feed.id, evt.room_id, setting)
send_type = "m.notice" if setting else "m.text"
await evt.reply(f"Updates for feed ID {feed.id} will now be sent as `{send_type}`")
@staticmethod
def _format_subscription(feed: Feed, subscriber: str) -> str:
msg = (f"* {feed.id} - [{feed.title}]({feed.url}) "
f"(subscribed by [{subscriber}](https://matrix.to/#/{subscriber}))")
msg = (
f"* {feed.id} - [{feed.title}]({feed.url}) "
f"(subscribed by [{subscriber}](https://matrix.to/#/{subscriber}))"
)
if feed.error_count > 1:
msg += f" \n ⚠️ The last {feed.error_count} attempts to fetch the feed have failed!"
return msg
@rss.subcommand("subscriptions", aliases=("ls", "list", "subs"),
help="List the subscriptions in the current room.")
@rss.subcommand(
"subscriptions",
aliases=("ls", "list", "subs"),
help="List the subscriptions in the current room.",
)
async def command_subscriptions(self, evt: MessageEvent) -> None:
subscriptions = self.db.get_feeds_by_room(evt.room_id)
await evt.reply("**Subscriptions in this room:**\n\n"
+ "\n".join(self._format_subscription(feed, subscriber)
for feed, subscriber in subscriptions))
subscriptions = await self.dbm.get_feeds_by_room(evt.room_id)
await evt.reply(
"**Subscriptions in this room:**\n\n"
+ "\n".join(
self._format_subscription(feed, subscriber) for feed, subscriber in subscriptions
)
)
@event.on(EventType.ROOM_TOMBSTONE)
async def tombstone(self, evt: StateEvent) -> None:
if not evt.content.replacement_room:
return
self.db.update_room_id(evt.room_id, evt.content.replacement_room)
await self.dbm.update_room_id(evt.room_id, evt.content.replacement_room)

374
rss/db.py
View File

@ -1,5 +1,5 @@
# rss - A maubot plugin to subscribe to RSS/Atom feeds.
# Copyright (C) 2020 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,221 +13,207 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable, NamedTuple, List, Optional, Dict, Tuple
from __future__ import annotations
from datetime import datetime
from string import Template
from sqlalchemy import (Column, String, Integer, DateTime, Text, Boolean, ForeignKey,
Table, MetaData,
select, and_, true)
from sqlalchemy.engine.base import Engine
from asyncpg import Record
from attr import dataclass
import attr
from mautrix.types import UserID, RoomID
Subscription = NamedTuple("Subscription", feed_id=int, room_id=RoomID, user_id=UserID,
notification_template=Template, send_notice=bool)
Feed = NamedTuple("Feed", id=int, url=str, title=str, subtitle=str, link=str, next_retry=int,
error_count=int, subscriptions=List[Subscription])
Entry = NamedTuple("Entry", feed_id=int, id=str, date=datetime, title=str, summary=str, link=str)
from mautrix.types import RoomID, UserID
from mautrix.util.async_db import Database, Scheme
class Database:
db: Engine
feed: Table
subscription: Table
entry: Table
version: Table
@dataclass
class Subscription:
feed_id: int
room_id: RoomID
user_id: UserID
notification_template: Template
send_notice: bool
def __init__(self, db: Engine) -> None:
@classmethod
def from_row(cls, row: Record | None) -> Subscription | None:
if not row:
return None
feed_id = row["id"]
room_id = row["room_id"]
user_id = row["user_id"]
if not room_id or not user_id:
return None
send_notice = bool(row["send_notice"])
tpl = Template(row["notification_template"])
return cls(
feed_id=feed_id,
room_id=room_id,
user_id=user_id,
notification_template=tpl,
send_notice=send_notice,
)
@dataclass
class Feed:
id: int
url: str
title: str
subtitle: str
link: str
next_retry: int = 0
error_count: int = 0
subscriptions: list[Subscription] = attr.ib(factory=lambda: [])
@classmethod
def from_row(cls, row: Record | None) -> Feed | None:
if not row:
return None
data = {**row}
data.pop("room_id", None)
data.pop("user_id", None)
data.pop("send_notice", None)
data.pop("notification_template", None)
return cls(**data, subscriptions=[])
date_fmt = "%Y-%m-%d %H:%M:%S"
date_fmt_microseconds = "%Y-%m-%d %H:%M:%S.%f"
@dataclass
class Entry:
feed_id: int
id: str
date: datetime
title: str
summary: str
link: str
@classmethod
def from_row(cls, row: Record | None) -> Entry | None:
if not row:
return None
data = {**row}
date = data.pop("date")
if not isinstance(date, datetime):
try:
date = datetime.strptime(date, date_fmt_microseconds if "." in date else date_fmt)
except ValueError:
date = datetime.now()
return cls(**data, date=date)
class DBManager:
db: Database
def __init__(self, db: Database) -> None:
self.db = db
metadata = MetaData()
self.feed = Table("feed", metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("url", Text, nullable=False, unique=True),
Column("title", Text, nullable=False),
Column("subtitle", Text, nullable=False),
Column("link", Text, nullable=False),
Column("next_retry", Integer, nullable=False),
Column("error_count", Integer, nullable=False))
self.subscription = Table("subscription", metadata,
Column("feed_id", Integer, ForeignKey("feed.id"),
primary_key=True),
Column("room_id", String(255), primary_key=True),
Column("user_id", String(255), nullable=False),
Column("notification_template", String(255), nullable=True),
Column("send_notice", Boolean, nullable=False,
server_default=true()))
self.entry = Table("entry", metadata,
Column("feed_id", Integer, ForeignKey("feed.id"), primary_key=True),
Column("id", String(255), primary_key=True),
Column("date", DateTime, nullable=False),
Column("title", Text, nullable=False),
Column("summary", Text, nullable=False),
Column("link", Text, nullable=False))
self.version = Table("version", metadata,
Column("version", Integer, primary_key=True))
self.upgrade()
def upgrade(self) -> None:
self.db.execute("CREATE TABLE IF NOT EXISTS version (version INTEGER PRIMARY KEY)")
try:
version, = next(self.db.execute(select([self.version.c.version])))
except (StopIteration, IndexError):
version = 0
if version == 0:
self.db.execute("""CREATE TABLE IF NOT EXISTS feed (
id INTEGER NOT NULL,
url TEXT NOT NULL,
title TEXT NOT NULL,
subtitle TEXT NOT NULL,
link TEXT NOT NULL,
PRIMARY KEY (id),
UNIQUE (url)
)""")
self.db.execute("""CREATE TABLE IF NOT EXISTS subscription (
feed_id INTEGER NOT NULL,
room_id VARCHAR(255) NOT NULL,
user_id VARCHAR(255) NOT NULL,
notification_template VARCHAR(255),
PRIMARY KEY (feed_id, room_id),
FOREIGN KEY(feed_id) REFERENCES feed (id)
)""")
self.db.execute("""CREATE TABLE IF NOT EXISTS entry (
feed_id INTEGER NOT NULL,
id VARCHAR(255) NOT NULL,
date DATETIME NOT NULL,
title TEXT NOT NULL,
summary TEXT NOT NULL,
link TEXT NOT NULL,
PRIMARY KEY (feed_id, id),
FOREIGN KEY(feed_id) REFERENCES feed (id)
)""")
version = 1
if version == 1:
self.db.execute("ALTER TABLE subscription ADD COLUMN send_notice BOOLEAN DEFAULT true")
version = 2
if version == 2:
self.db.execute("ALTER TABLE feed ADD COLUMN next_retry BIGINT DEFAULT 0")
self.db.execute("ALTER TABLE feed ADD COLUMN error_count BIGINT DEFAULT 0")
version = 3
self.db.execute(self.version.delete())
self.db.execute(self.version.insert().values(version=version))
def get_feeds(self) -> Iterable[Feed]:
rows = self.db.execute(select([self.feed,
self.subscription.c.room_id,
self.subscription.c.user_id,
self.subscription.c.notification_template,
self.subscription.c.send_notice])
.where(self.subscription.c.feed_id == self.feed.c.id))
map: Dict[int, Feed] = {}
async def get_feeds(self) -> list[Feed]:
q = """
SELECT id, url, title, subtitle, link, next_retry, error_count,
room_id, user_id, notification_template, send_notice
FROM feed INNER JOIN subscription ON feed.id = subscription.feed_id
"""
rows = await self.db.fetch(q)
feeds: dict[int, Feed] = {}
for row in rows:
(feed_id, url, title, subtitle, link, next_retry, error_count,
room_id, user_id, notification_template, send_notice) = row
map.setdefault(feed_id, Feed(feed_id, url, title, subtitle, link, next_retry,
error_count, subscriptions=[]))
map[feed_id].subscriptions.append(
Subscription(feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template=Template(notification_template),
send_notice=send_notice))
return map.values()
try:
feed = feeds[row["id"]]
except KeyError:
feed = feeds[row["id"]] = Feed.from_row(row)
feed.subscriptions.append(Subscription.from_row(row))
return list(feeds.values())
def get_feeds_by_room(self, room_id: RoomID) -> Iterable[Tuple[Feed, UserID]]:
return ((Feed(feed_id, url, title, subtitle, link, next_retry, error_count,
subscriptions=[]),
user_id)
for (feed_id, url, title, subtitle, link, next_retry, error_count, user_id) in
self.db.execute(select([self.feed, self.subscription.c.user_id])
.where(and_(self.subscription.c.room_id == room_id,
self.subscription.c.feed_id == self.feed.c.id))))
async def get_feeds_by_room(self, room_id: RoomID) -> list[tuple[Feed, UserID]]:
q = """
SELECT id, url, title, subtitle, link, next_retry, error_count, user_id FROM feed
INNER JOIN subscription ON feed.id = subscription.feed_id AND subscription.room_id = $1
"""
rows = await self.db.fetch(q, room_id)
return [(Feed.from_row(row), row["user_id"]) for row in rows]
def get_rooms_by_feed(self, feed_id: int) -> Iterable[RoomID]:
return (row[0] for row in
self.db.execute(select([self.subscription.c.room_id])
.where(self.subscription.c.feed_id == feed_id)))
async def get_entries(self, feed_id: int) -> list[Entry]:
q = "SELECT feed_id, id, date, title, summary, link FROM entry WHERE feed_id = $1"
return [Entry.from_row(row) for row in await self.db.fetch(q, feed_id)]
def get_entries(self, feed_id: int) -> Iterable[Entry]:
return (Entry(*row) for row in
self.db.execute(select([self.entry]).where(self.entry.c.feed_id == feed_id)))
def add_entries(self, entries: Iterable[Entry], override_feed_id: Optional[int] = None) -> None:
async def add_entries(self, entries: list[Entry], override_feed_id: int | None = None) -> None:
if not entries:
return
entries = [entry._asdict() for entry in entries]
if override_feed_id is not None:
if override_feed_id:
for entry in entries:
entry["feed_id"] = override_feed_id
self.db.execute(self.entry.insert(), entries)
entry.feed_id = override_feed_id
records = [attr.astuple(entry) for entry in entries]
columns = ("feed_id", "id", "date", "title", "summary", "link")
async with self.db.acquire() as conn:
if self.db.scheme == Scheme.POSTGRES:
await conn.copy_records_to_table("entry", records=records, columns=columns)
else:
q = (
"INSERT INTO entry (feed_id, id, date, title, summary, link) "
"VALUES ($1, $2, $3, $4, $5, $6)"
)
await conn.executemany(q, records)
def get_feed_by_url(self, url: str) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(self.feed.c.url == url))
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (ValueError, StopIteration):
return None
async def get_feed_by_url(self, url: str) -> Feed | None:
q = "SELECT id, url, title, subtitle, link, next_retry, error_count FROM feed WHERE url=$1"
return Feed.from_row(await self.db.fetchrow(q, url))
def get_feed_by_id(self, feed_id: int) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(self.feed.c.id == feed_id))
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (ValueError, StopIteration):
return None
async def get_subscription(
self, feed_id: int, room_id: RoomID
) -> tuple[Subscription | None, Feed | None]:
q = """
SELECT id, url, title, subtitle, link, next_retry, error_count,
room_id, user_id, notification_template, send_notice
FROM feed LEFT JOIN subscription ON feed.id = subscription.feed_id AND room_id = $2
WHERE feed.id = $1
"""
row = await self.db.fetchrow(q, feed_id, room_id)
return Subscription.from_row(row), Feed.from_row(row)
def get_subscription(self, feed_id: int, room_id: RoomID) -> Tuple[Optional[Subscription],
Optional[Feed]]:
tbl = self.subscription
rows = self.db.execute(select([self.feed, tbl.c.room_id, tbl.c.user_id,
tbl.c.notification_template, tbl.c.send_notice])
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id,
self.feed.c.id == feed_id)))
try:
(feed_id, url, title, subtitle, link, next_retry, error_count,
room_id, user_id, template, send_notice) = next(rows)
notification_template = Template(template)
return (Subscription(feed_id, room_id, user_id, notification_template, send_notice)
if room_id else None,
Feed(feed_id, url, title, subtitle, link, next_retry, error_count, []))
except (ValueError, StopIteration):
return None, None
async def update_room_id(self, old: RoomID, new: RoomID) -> None:
await self.db.execute("UPDATE subscription SET room_id = $1 WHERE room_id = $2", new, old)
def update_room_id(self, old: RoomID, new: RoomID) -> None:
self.db.execute(self.subscription.update()
.where(self.subscription.c.room_id == old)
.values(room_id=new))
async def create_feed(self, info: Feed) -> Feed:
q = (
"INSERT INTO feed (url, title, subtitle, link, next_retry) "
"VALUES ($1, $2, $3, $4, $5) RETURNING (id)"
)
info.id = await self.db.fetchval(
q, info.url, info.title, info.subtitle, info.link, info.next_retry
)
return info
def create_feed(self, info: Feed) -> Feed:
res = self.db.execute(self.feed.insert().values(url=info.url, title=info.title,
subtitle=info.subtitle, link=info.link,
next_retry=info.next_retry))
return Feed(id=res.inserted_primary_key[0], url=info.url, title=info.title,
subtitle=info.subtitle, link=info.link, next_retry=info.next_retry,
error_count=info.error_count, subscriptions=[])
async def set_backoff(self, info: Feed, error_count: int, next_retry: int) -> None:
q = "UPDATE feed SET error_count = $2, next_retry = $3 WHERE id = $1"
await self.db.execute(q, info.id, error_count, next_retry)
def set_backoff(self, info: Feed, error_count: int, next_retry: int) -> None:
self.db.execute(self.feed.update()
.where(self.feed.c.id == info.id)
.values(error_count=error_count, next_retry=next_retry))
async def subscribe(
self,
feed_id: int,
room_id: RoomID,
user_id: UserID,
template: str | None = None,
send_notice: bool = True,
) -> None:
q = """
INSERT INTO subscription (feed_id, room_id, user_id, notification_template, send_notice)
VALUES ($1, $2, $3, $4, $5)
"""
template = template or "New post in $feed_title: [$title]($link)"
await self.db.execute(q, feed_id, room_id, user_id, template, send_notice)
def subscribe(self, feed_id: int, room_id: RoomID, user_id: UserID) -> None:
self.db.execute(self.subscription.insert().values(
feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template="New post in $feed_title: [$title]($link)"))
async def unsubscribe(self, feed_id: int, room_id: RoomID) -> None:
q = "DELETE FROM subscription WHERE feed_id = $1 AND room_id = $2"
await self.db.execute(q, feed_id, room_id)
def unsubscribe(self, feed_id: int, room_id: RoomID) -> None:
tbl = self.subscription
self.db.execute(tbl.delete().where(and_(tbl.c.feed_id == feed_id,
tbl.c.room_id == room_id)))
async def update_template(self, feed_id: int, room_id: RoomID, template: str) -> None:
q = "UPDATE subscription SET notification_template=$3 WHERE feed_id=$1 AND room_id=$2"
await self.db.execute(q, feed_id, room_id, template)
def update_template(self, feed_id: int, room_id: RoomID, template: str) -> None:
tbl = self.subscription
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(notification_template=template))
def set_send_notice(self, feed_id: int, room_id: RoomID, send_notice: bool) -> None:
tbl = self.subscription
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(send_notice=send_notice))
async def set_send_notice(self, feed_id: int, room_id: RoomID, send_notice: bool) -> None:
q = "UPDATE subscription SET send_notice=$3 WHERE feed_id=$1 AND room_id=$2"
await self.db.execute(q, feed_id, room_id, send_notice)

View File

@ -1,5 +1,5 @@
# rss - A maubot plugin to subscribe to RSS/Atom feeds.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,62 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import select
from sqlalchemy.engine.base import Engine
from alembic.migration import MigrationContext
from alembic.operations import Operations
from mautrix.util.async_db import Connection, Scheme, UpgradeTable
upgrade_table = UpgradeTable()
def run(engine: Engine):
conn = engine.connect()
ctx = MigrationContext.configure(conn)
op = Operations(ctx)
@upgrade_table.register(description="Latest revision", upgrades_to=3)
async def upgrade_latest(conn: Connection, scheme: Scheme) -> None:
gen = "GENERATED ALWAYS AS IDENTITY" if scheme != Scheme.SQLITE else ""
await conn.execute(
f"""CREATE TABLE IF NOT EXISTS feed (
id INTEGER {gen},
url TEXT NOT NULL,
title TEXT NOT NULL,
subtitle TEXT NOT NULL,
link TEXT NOT NULL,
next_retry BIGINT DEFAULT 0,
error_count BIGINT DEFAULT 0,
PRIMARY KEY (id),
UNIQUE (url)
)"""
)
await conn.execute(
"""CREATE TABLE IF NOT EXISTS subscription (
feed_id INTEGER,
room_id TEXT,
user_id TEXT NOT NULL,
notification_template TEXT,
send_notice BOOLEAN DEFAULT true,
PRIMARY KEY (feed_id, room_id),
FOREIGN KEY (feed_id) REFERENCES feed (id)
)"""
)
await conn.execute(
"""CREATE TABLE entry (
feed_id INTEGER,
id TEXT,
date timestamp NOT NULL,
title TEXT NOT NULL,
summary TEXT NOT NULL,
link TEXT NOT NULL,
PRIMARY KEY (feed_id, id),
FOREIGN KEY (feed_id) REFERENCES feed (id)
);"""
)
@upgrade_table.register(description="Add send_notice field to subscriptions")
async def upgrade_v2(conn: Connection) -> None:
await conn.execute("ALTER TABLE subscription ADD COLUMN send_notice BOOLEAN DEFAULT true")
@upgrade_table.register(description="Add error counts to feeds")
async def upgrade_v3(conn: Connection) -> None:
await conn.execute("ALTER TABLE feed ADD COLUMN next_retry BIGINT DEFAULT 0")
await conn.execute("ALTER TABLE feed ADD COLUMN error_count BIGINT DEFAULT 0")