maubot-rss/rss/db.py

234 lines
12 KiB
Python
Raw Normal View History

2018-11-26 19:33:41 -05:00
# rss - A maubot plugin to subscribe to RSS/Atom feeds.
2020-05-20 18:10:25 -04:00
# Copyright (C) 2020 Tulir Asokan
2018-11-26 19:33:41 -05:00
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable, NamedTuple, List, Optional, Dict, Tuple
2018-11-26 19:33:41 -05:00
from datetime import datetime
from string import Template
2018-11-26 19:33:41 -05:00
2020-07-01 10:20:14 -04:00
from sqlalchemy import (Column, String, Integer, DateTime, Text, Boolean, ForeignKey,
2018-11-26 19:33:41 -05:00
Table, MetaData,
2020-07-01 10:20:14 -04:00
select, and_, true)
2018-11-26 19:33:41 -05:00
from sqlalchemy.engine.base import Engine
from mautrix.types import UserID, RoomID
Subscription = NamedTuple("Subscription", feed_id=int, room_id=RoomID, user_id=UserID,
2020-07-01 10:20:14 -04:00
notification_template=Template, send_notice=bool)
Feed = NamedTuple("Feed", id=int, url=str, title=str, subtitle=str, link=str, next_retry=int,
error_count=int, subscriptions=List[Subscription])
2018-11-26 19:33:41 -05:00
Entry = NamedTuple("Entry", feed_id=int, id=str, date=datetime, title=str, summary=str, link=str)
class Database:
db: Engine
feed: Table
subscription: Table
entry: Table
version: Table
def __init__(self, db: Engine) -> None:
self.db = db
metadata = MetaData()
self.feed = Table("feed", metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("url", Text, nullable=False, unique=True),
Column("title", Text, nullable=False),
Column("subtitle", Text, nullable=False),
Column("link", Text, nullable=False),
Column("next_retry", Integer, nullable=False),
Column("error_count", Integer, nullable=False))
2018-11-26 19:33:41 -05:00
self.subscription = Table("subscription", metadata,
Column("feed_id", Integer, ForeignKey("feed.id"),
primary_key=True),
Column("room_id", String(255), primary_key=True),
Column("user_id", String(255), nullable=False),
2020-07-01 10:20:14 -04:00
Column("notification_template", String(255), nullable=True),
Column("send_notice", Boolean, nullable=False,
server_default=true()))
2018-11-26 19:33:41 -05:00
self.entry = Table("entry", metadata,
Column("feed_id", Integer, ForeignKey("feed.id"), primary_key=True),
Column("id", String(255), primary_key=True),
Column("date", DateTime, nullable=False),
Column("title", Text, nullable=False),
Column("summary", Text, nullable=False),
Column("link", Text, nullable=False))
self.version = Table("version", metadata,
Column("version", Integer, primary_key=True))
2020-07-01 10:20:14 -04:00
self.upgrade()
def upgrade(self) -> None:
self.db.execute("CREATE TABLE IF NOT EXISTS version (version INTEGER PRIMARY KEY)")
2020-07-01 10:20:14 -04:00
try:
version, = next(self.db.execute(select([self.version.c.version])))
except (StopIteration, IndexError):
version = 0
if version == 0:
self.db.execute("""CREATE TABLE IF NOT EXISTS feed (
id INTEGER NOT NULL,
url TEXT NOT NULL,
title TEXT NOT NULL,
subtitle TEXT NOT NULL,
link TEXT NOT NULL,
PRIMARY KEY (id),
UNIQUE (url)
)""")
self.db.execute("""CREATE TABLE IF NOT EXISTS subscription (
feed_id INTEGER NOT NULL,
room_id VARCHAR(255) NOT NULL,
user_id VARCHAR(255) NOT NULL,
notification_template VARCHAR(255),
PRIMARY KEY (feed_id, room_id),
FOREIGN KEY(feed_id) REFERENCES feed (id)
)""")
self.db.execute("""CREATE TABLE IF NOT EXISTS entry (
feed_id INTEGER NOT NULL,
id VARCHAR(255) NOT NULL,
date DATETIME NOT NULL,
title TEXT NOT NULL,
summary TEXT NOT NULL,
link TEXT NOT NULL,
PRIMARY KEY (feed_id, id),
FOREIGN KEY(feed_id) REFERENCES feed (id)
)""")
version = 1
if version == 1:
self.db.execute("ALTER TABLE subscription ADD COLUMN send_notice BOOLEAN DEFAULT true")
version = 2
if version == 2:
self.db.execute("ALTER TABLE feed ADD COLUMN next_retry BIGINT DEFAULT 0")
self.db.execute("ALTER TABLE feed ADD COLUMN error_count BIGINT DEFAULT 0")
version = 3
2020-07-01 10:20:14 -04:00
self.db.execute(self.version.delete())
self.db.execute(self.version.insert().values(version=version))
2018-11-26 19:33:41 -05:00
def get_feeds(self) -> Iterable[Feed]:
rows = self.db.execute(select([self.feed,
self.subscription.c.room_id,
self.subscription.c.user_id,
2020-07-01 10:20:14 -04:00
self.subscription.c.notification_template,
self.subscription.c.send_notice])
2018-11-26 19:33:41 -05:00
.where(self.subscription.c.feed_id == self.feed.c.id))
map: Dict[int, Feed] = {}
for row in rows:
(feed_id, url, title, subtitle, link, next_retry, error_count,
2020-07-01 10:20:14 -04:00
room_id, user_id, notification_template, send_notice) = row
map.setdefault(feed_id, Feed(feed_id, url, title, subtitle, link, next_retry,
error_count, subscriptions=[]))
map[feed_id].subscriptions.append(
Subscription(feed_id=feed_id, room_id=room_id, user_id=user_id,
2020-07-01 10:20:14 -04:00
notification_template=Template(notification_template),
send_notice=send_notice))
2018-11-26 19:33:41 -05:00
return map.values()
def get_feeds_by_room(self, room_id: RoomID) -> Iterable[Tuple[Feed, UserID]]:
return ((Feed(feed_id, url, title, subtitle, link, next_retry, error_count,
subscriptions=[]),
user_id)
for (feed_id, url, title, subtitle, link, next_retry, error_count, user_id) in
self.db.execute(select([self.feed, self.subscription.c.user_id])
2018-11-26 19:33:41 -05:00
.where(and_(self.subscription.c.room_id == room_id,
self.subscription.c.feed_id == self.feed.c.id))))
def get_rooms_by_feed(self, feed_id: int) -> Iterable[RoomID]:
return (row[0] for row in
self.db.execute(select([self.subscription.c.room_id])
.where(self.subscription.c.feed_id == feed_id)))
def get_entries(self, feed_id: int) -> Iterable[Entry]:
return (Entry(*row) for row in
self.db.execute(select([self.entry]).where(self.entry.c.feed_id == feed_id)))
2020-08-02 20:03:19 -04:00
def add_entries(self, entries: Iterable[Entry], override_feed_id: Optional[int] = None) -> None:
2018-11-26 19:33:41 -05:00
if not entries:
return
2020-08-02 20:03:19 -04:00
entries = [entry._asdict() for entry in entries]
if override_feed_id is not None:
for entry in entries:
entry["feed_id"] = override_feed_id
self.db.execute(self.entry.insert(), entries)
2018-11-26 19:33:41 -05:00
def get_feed_by_url(self, url: str) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(self.feed.c.url == url))
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (ValueError, StopIteration):
2018-11-26 19:33:41 -05:00
return None
def get_feed_by_id(self, feed_id: int) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(self.feed.c.id == feed_id))
2018-11-26 19:33:41 -05:00
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (ValueError, StopIteration):
2018-11-26 19:33:41 -05:00
return None
def get_subscription(self, feed_id: int, room_id: RoomID) -> Tuple[Optional[Subscription],
Optional[Feed]]:
tbl = self.subscription
rows = self.db.execute(select([self.feed, tbl.c.room_id, tbl.c.user_id,
2020-07-01 10:20:14 -04:00
tbl.c.notification_template, tbl.c.send_notice])
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id,
self.feed.c.id == feed_id)))
try:
(feed_id, url, title, subtitle, link, next_retry, error_count,
2020-07-01 10:20:14 -04:00
room_id, user_id, template, send_notice) = next(rows)
notification_template = Template(template)
2020-07-01 10:20:14 -04:00
return (Subscription(feed_id, room_id, user_id, notification_template, send_notice)
if room_id else None,
Feed(feed_id, url, title, subtitle, link, next_retry, error_count, []))
except (ValueError, StopIteration):
2020-05-20 18:10:25 -04:00
return None, None
2019-11-21 16:50:56 -05:00
def update_room_id(self, old: RoomID, new: RoomID) -> None:
self.db.execute(self.subscription.update()
.where(self.subscription.c.room_id == old)
.values(room_id=new))
2020-08-02 20:03:19 -04:00
def create_feed(self, info: Feed) -> Feed:
res = self.db.execute(self.feed.insert().values(url=info.url, title=info.title,
subtitle=info.subtitle, link=info.link,
next_retry=info.next_retry))
2020-08-02 20:03:19 -04:00
return Feed(id=res.inserted_primary_key[0], url=info.url, title=info.title,
subtitle=info.subtitle, link=info.link, next_retry=info.next_retry,
error_count=info.error_count, subscriptions=[])
def set_backoff(self, info: Feed, error_count: int, next_retry: int) -> None:
self.db.execute(self.feed.update()
.where(self.feed.c.id == info.id)
.values(error_count=error_count, next_retry=next_retry))
2018-11-26 19:33:41 -05:00
def subscribe(self, feed_id: int, room_id: RoomID, user_id: UserID) -> None:
self.db.execute(self.subscription.insert().values(
feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template="New post in $feed_title: [$title]($link)"))
2018-11-26 19:33:41 -05:00
def unsubscribe(self, feed_id: int, room_id: RoomID) -> None:
tbl = self.subscription
self.db.execute(tbl.delete().where(and_(tbl.c.feed_id == feed_id,
tbl.c.room_id == room_id)))
def update_template(self, feed_id: int, room_id: RoomID, template: str) -> None:
tbl = self.subscription
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(notification_template=template))
2020-07-01 10:20:14 -04:00
def set_send_notice(self, feed_id: int, room_id: RoomID, send_notice: bool) -> None:
tbl = self.subscription
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(send_notice=send_notice))