Add option to send posts as m.text

This commit is contained in:
Tulir Asokan 2020-07-01 17:20:14 +03:00
parent 391db1405f
commit 89ebfe7283
2 changed files with 102 additions and 18 deletions

View File

@ -40,6 +40,21 @@ class Config(BaseProxyConfig):
helper.copy("admins")
class BoolArgument(command.Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False) -> None:
super().__init__(name, label, required=required, pass_raw=False)
def match(self, val: str, **kwargs) -> Tuple[str, Any]:
part = val.split(" ")[0].lower()
if part in ("f", "false", "n", "no", "0"):
res = False
elif part in ("t", "true", "y", "yes", "1"):
res = True
else:
raise ValueError("invalid boolean")
return val[len(part):], res
class RSSBot(Plugin):
db: Database
poll_task: asyncio.Future
@ -70,20 +85,18 @@ class RSSBot(Plugin):
except Exception:
self.log.exception("Fatal error while polling feeds")
def _send(self, feed: Feed, entry: Entry, template: Template, room_id: RoomID
) -> Awaitable[EventID]:
return self.client.send_markdown(room_id, template.safe_substitute({
def _send(self, feed: Feed, entry: Entry, sub: Subscription) -> Awaitable[EventID]:
return self.client.send_markdown(sub.room_id, sub.notification_template.safe_substitute({
"feed_url": feed.url,
"feed_title": feed.title,
"feed_subtitle": feed.subtitle,
"feed_link": feed.link,
**entry._asdict(),
}), msgtype=MessageType.NOTICE, allow_html=True)
}), msgtype=MessageType.NOTICE if sub.send_notice else MessageType.TEXT, allow_html=True)
async def _broadcast(self, feed: Feed, entry: Entry, subscriptions: List[Subscription]) -> None:
spam_sleep = self.config["spam_sleep"]
tasks = [self._send(feed, entry, sub.notification_template, sub.room_id)
for sub in subscriptions]
tasks = [self._send(feed, entry, sub) for sub in subscriptions]
if spam_sleep >= 0:
for task in tasks:
await task
@ -148,7 +161,7 @@ class RSSBot(Plugin):
def find_entries(cls, feed_id: int, entries: List[Any]) -> List[Entry]:
return [Entry(
feed_id=feed_id,
id=(getattr(entry, "id") or
id=(getattr(entry, "id", None) or
hashlib.sha1(" ".join([getattr(entry, "title", ""),
getattr(entry, "description", ""),
getattr(entry, "link", "")]).encode("utf-8")
@ -241,6 +254,21 @@ class RSSBot(Plugin):
await evt.reply(f"Template for feed ID {feed.id} updated. Sample notification:")
await self._send(feed, sample_entry, Template(template), sub.room_id)
@rss.subcommand("notice", aliases=("n",),
help="Set whether or not the bot should send updates as m.notice")
@command.argument("feed_id", "feed ID", parser=int)
@BoolArgument("setting", "true/false")
async def command_notice(self, evt: MessageEvent, feed_id: int, setting: bool) -> None:
if not await self.can_manage(evt):
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.set_send_notice(feed.id, evt.room_id, setting)
send_type = "m.notice" if setting else "m.text"
await evt.reply(f"Updates for feed ID {feed.id} will now be sent as `{send_type}`")
@rss.subcommand("subscriptions", aliases=("ls", "list", "subs"),
help="List the subscriptions in the current room.")
async def command_subscriptions(self, evt: MessageEvent) -> None:

View File

@ -17,15 +17,15 @@ from typing import Iterable, NamedTuple, List, Optional, Dict, Tuple
from datetime import datetime
from string import Template
from sqlalchemy import (Column, String, Integer, DateTime, Text, ForeignKey,
from sqlalchemy import (Column, String, Integer, DateTime, Text, Boolean, ForeignKey,
Table, MetaData,
select, and_)
select, and_, true)
from sqlalchemy.engine.base import Engine
from mautrix.types import UserID, RoomID
Subscription = NamedTuple("Subscription", feed_id=int, room_id=RoomID, user_id=UserID,
notification_template=Template)
notification_template=Template, send_notice=bool)
Feed = NamedTuple("Feed", id=int, url=str, title=str, subtitle=str, link=str,
subscriptions=List[Subscription])
Entry = NamedTuple("Entry", feed_id=int, id=str, date=datetime, title=str, summary=str, link=str)
@ -52,7 +52,9 @@ class Database:
primary_key=True),
Column("room_id", String(255), primary_key=True),
Column("user_id", String(255), nullable=False),
Column("notification_template", String(255), nullable=True))
Column("notification_template", String(255), nullable=True),
Column("send_notice", Boolean, nullable=False,
server_default=true()))
self.entry = Table("entry", metadata,
Column("feed_id", Integer, ForeignKey("feed.id"), primary_key=True),
Column("id", String(255), primary_key=True),
@ -62,21 +64,68 @@ class Database:
Column("link", Text, nullable=False))
self.version = Table("version", metadata,
Column("version", Integer, primary_key=True))
metadata.create_all(db)
self.upgrade()
def upgrade(self) -> None:
try:
version, = next(self.db.execute(select([self.version.c.version])))
except (StopIteration, IndexError):
version = 0
if version == 0:
self.db.execute("""CREATE TABLE IF NOT EXISTS feed (
id INTEGER NOT NULL,
url TEXT NOT NULL,
title TEXT NOT NULL,
subtitle TEXT NOT NULL,
link TEXT NOT NULL,
PRIMARY KEY (id),
UNIQUE (url)
)""")
self.db.execute("""CREATE TABLE IF NOT EXISTS version (
version INTEGER NOT NULL,
PRIMARY KEY (version)
)""")
self.db.execute("""CREATE TABLE IF NOT EXISTS subscription (
feed_id INTEGER NOT NULL,
room_id VARCHAR(255) NOT NULL,
user_id VARCHAR(255) NOT NULL,
notification_template VARCHAR(255),
PRIMARY KEY (feed_id, room_id),
FOREIGN KEY(feed_id) REFERENCES feed (id)
)""")
self.db.execute("""CREATE TABLE IF NOT EXISTS entry (
feed_id INTEGER NOT NULL,
id VARCHAR(255) NOT NULL,
date DATETIME NOT NULL,
title TEXT NOT NULL,
summary TEXT NOT NULL,
link TEXT NOT NULL,
PRIMARY KEY (feed_id, id),
FOREIGN KEY(feed_id) REFERENCES feed (id)
)""")
version = 1
if version == 1:
self.db.execute("ALTER TABLE subscription ADD COLUMN send_notice BOOLEAN DEFAULT true")
version = 2
self.db.execute(self.version.delete())
self.db.execute(self.version.insert().values(version=version))
def get_feeds(self) -> Iterable[Feed]:
rows = self.db.execute(select([self.feed,
self.subscription.c.room_id,
self.subscription.c.user_id,
self.subscription.c.notification_template])
self.subscription.c.notification_template,
self.subscription.c.send_notice])
.where(self.subscription.c.feed_id == self.feed.c.id))
map: Dict[int, Feed] = {}
for row in rows:
feed_id, url, title, subtitle, link, room_id, user_id, notification_template = row
(feed_id, url, title, subtitle, link,
room_id, user_id, notification_template, send_notice) = row
map.setdefault(feed_id, Feed(feed_id, url, title, subtitle, link, subscriptions=[]))
map[feed_id].subscriptions.append(
Subscription(feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template=Template(notification_template)))
notification_template=Template(notification_template),
send_notice=send_notice))
return map.values()
def get_feeds_by_room(self, room_id: RoomID) -> Iterable[Tuple[Feed, UserID]]:
@ -120,13 +169,14 @@ class Database:
Optional[Feed]]:
tbl = self.subscription
rows = self.db.execute(select([self.feed, tbl.c.room_id, tbl.c.user_id,
tbl.c.notification_template])
tbl.c.notification_template, tbl.c.send_notice])
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id,
self.feed.c.id == feed_id)))
try:
feed_id, url, title, subtitle, link, room_id, user_id, template = next(rows)
(feed_id, url, title, subtitle, link,
room_id, user_id, template, send_notice) = next(rows)
notification_template = Template(template)
return (Subscription(feed_id, room_id, user_id, notification_template)
return (Subscription(feed_id, room_id, user_id, notification_template, send_notice)
if room_id else None,
Feed(feed_id, url, title, subtitle, link, []))
except (ValueError, StopIteration):
@ -158,3 +208,9 @@ class Database:
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(notification_template=template))
def set_send_notice(self, feed_id: int, room_id: RoomID, send_notice: bool) -> None:
tbl = self.subscription
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(send_notice=send_notice))