Add permission checks and add support for custom notification templates

This commit is contained in:
Tulir Asokan 2018-11-28 00:41:22 +02:00
parent 2ee880d9cb
commit 778be0aa19
2 changed files with 131 additions and 40 deletions

View File

@ -13,19 +13,22 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, List, Any
from typing import Type, List, Any, Dict, Tuple, Awaitable
from datetime import datetime
from time import mktime
from time import mktime, time
from string import Template
import asyncio
import aiohttp
import commonmark
import feedparser
from maubot import Plugin, MessageEvent
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from mautrix.types import EventType, MessageType, RoomID
from mautrix.types import (EventType, MessageType, RoomID, EventID, PowerLevelStateEventContent,
TextMessageEventContent, Format)
from .db import Database, Feed, Entry
from .db import Database, Feed, Entry, Subscription
class Config(BaseProxyConfig):
@ -38,6 +41,7 @@ class RSSBot(Plugin):
db: Database
poll_task: asyncio.Future
http: aiohttp.ClientSession
power_level_cache: Dict[RoomID, Tuple[int, PowerLevelStateEventContent]]
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
@ -48,7 +52,7 @@ class RSSBot(Plugin):
self.db = Database(self.request_db_engine())
self.client.add_event_handler(self.event_handler, EventType.ROOM_MESSAGE)
self.http = self.client.api.session
self.power_level_cache = {}
self.poll_task = asyncio.ensure_future(self.poll_feeds(), loop=self.loop)
async def stop(self) -> None:
@ -63,12 +67,24 @@ class RSSBot(Plugin):
except Exception:
self.log.exception("Fatal error while polling feeds")
async def _broadcast(self, feed: Feed, entry: Entry, subscriptions: List[RoomID]) -> None:
text = f"New post in {feed.title}: {entry.title} ({entry.link})"
html = f"New post in {feed.title}: <a href='{entry.link}'>{entry.title}</a>"
def _send(self, feed: Feed, entry: Entry, template: Template, room_id: RoomID) -> Awaitable[EventID]:
message = template.safe_substitute({
"feed_url": feed.url,
"feed_title": feed.title,
"feed_subtitle": feed.subtitle,
"feed_link": feed.link,
**entry._asdict(),
})
content = TextMessageEventContent(msgtype=MessageType.NOTICE,
body=message,
format=Format.HTML,
formatted_body=commonmark.commonmark(message))
return self.client.send_message(room_id, content)
async def _broadcast(self, feed: Feed, entry: Entry, subscriptions: List[Subscription]) -> None:
spam_sleep = self.config["spam_sleep"]
tasks = [self.client.send_notice(room_id, text=text, html=html) for room_id in
subscriptions]
tasks = [self._send(feed, entry, sub.notification_template, sub.room_id)
for sub in subscriptions]
if spam_sleep >= 0:
for task in tasks:
await task
@ -80,9 +96,8 @@ class RSSBot(Plugin):
subs = self.db.get_feeds()
if not subs:
return
responses = await asyncio.gather(*[self.http.get(feed.url) for feed in subs], loop=self.loop)
texts = await asyncio.gather(*[resp.text() for resp in responses], loop=self.loop)
for feed, data in zip(subs, texts):
datas = await asyncio.gather(*[self.read_feed(feed.url) for feed in subs], loop=self.loop)
for feed, data in zip(subs, datas):
parsed_data = feedparser.parse(data)
entries = parsed_data.entries
new_entries = {entry.id: entry for entry in self.find_entries(feed.id, entries)}
@ -101,10 +116,10 @@ class RSSBot(Plugin):
self.log.exception("Error while polling feeds")
await asyncio.sleep(self.config["update_interval"] * 60, loop=self.loop)
async def read_feed(self, url: str):
async def read_feed(self, url: str) -> str:
resp = await self.http.get(url)
content = await resp.text()
return feedparser.parse(content)
return content
@staticmethod
def get_date(entry: Any) -> datetime:
@ -129,6 +144,24 @@ class RSSBot(Plugin):
link=entry.link,
) for entry in entries]
async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent:
try:
expiry, levels = self.power_level_cache[room_id]
if expiry < int(time()):
return levels
except KeyError:
pass
levels = await self.client.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
self.power_level_cache[room_id] = (int(time()) + 5 * 60, levels)
return levels
async def can_manage(self, evt: MessageEvent) -> bool:
levels = await self.get_power_levels(evt.room_id)
if levels.get_user_level(evt.sender) < levels.state_default:
await evt.reply("You don't the permission to manage the subscriptions of this room.")
return False
return True
async def event_handler(self, evt: MessageEvent) -> None:
if evt.content.msgtype != MessageType.TEXT or not evt.content.body.startswith("!rss"):
return
@ -136,13 +169,15 @@ class RSSBot(Plugin):
args = evt.content.body[len("!rss "):].split(" ")
cmd, args = args[0].lower(), args[1:]
if cmd == "sub" or cmd == "subscribe":
if len(args) == 0:
if not await self.can_manage(evt):
return
elif len(args) == 0:
await evt.reply(f"**Usage:** !rss {cmd} <feed URL>")
return
url = " ".join(args)
feed = self.db.get_feed_by_url(url)
if not feed:
metadata = await self.read_feed(url)
metadata = feedparser.parse(await self.read_feed(url))
if metadata.bozo:
await evt.reply("That doesn't look like a valid feed.")
return
@ -154,19 +189,43 @@ class RSSBot(Plugin):
self.db.subscribe(feed.id, evt.room_id, evt.sender)
await evt.reply(f"Subscribed to feed ID {feed.id}: [{feed.title}]({feed.url})")
elif cmd == "unsub" or cmd == "unsubscribe":
if len(args) == 0:
if not await self.can_manage(evt):
return
try:
feed_id = int(args[0])
except (ValueError, IndexError):
await evt.reply(f"**Usage:** !rss {cmd} <feed ID>")
return
feed = self.db.get_feed_by_id_or_url(" ".join(args))
if not feed:
await evt.reply("Feed not found")
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.unsubscribe(feed.id, evt.room_id)
await evt.reply(f"Unsubscribed from feed ID {feed.id}: [{feed.title}]({feed.url})")
elif cmd == "template" or cmd == "tpl":
if not await self.can_manage(evt):
return
try:
feed_id = int(args[0])
template = " ".join(args[1:])
except (ValueError, IndexError):
await evt.reply(f"**Usage:** !rss {cmd} <feed ID> <new template>")
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.update_template(feed.id, evt.room_id, template)
sample_entry = Entry(feed.id, "SAMPLE", datetime.now(), "Sample entry",
"This is a sample entry to demonstrate your new template",
"http://example.com")
await evt.reply(f"Template for feed ID {feed.id} updated. Sample notification:")
await self._send(feed, sample_entry, Template(template), sub.room_id)
elif cmd == "subs" or cmd == "subscriptions":
subscriptions = self.db.get_feeds_by_room(evt.room_id)
await evt.reply("**Subscriptions in this room:**\n\n"
+ "\n".join(f"* {feed.id} - [{feed.title}]({feed.url})"
for feed in subscriptions))
+ "\n".join(f"* {feed.id} - [{feed.title}]({feed.url}) (subscribed by "
f"[{subscriber}](https://matrix.to/#/{subscriber}))"
for feed, subscriber in subscriptions))
else:
await evt.reply("**Usage:** !rss <sub/unsub/subs> [params...]")

View File

@ -13,18 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable, NamedTuple, List, Optional, Dict
from typing import Iterable, NamedTuple, List, Optional, Dict, Tuple
from datetime import datetime
from string import Template
from sqlalchemy import (Column, String, Integer, DateTime, Text, ForeignKey,
Table, MetaData,
select, and_, or_)
select, and_)
from sqlalchemy.engine.base import Engine
from mautrix.types import UserID, RoomID
Subscription = NamedTuple("Subscription", feed_id=int, room_id=str, user_id=str,
notification_template=Template)
Feed = NamedTuple("Feed", id=int, url=str, title=str, subtitle=str, link=str,
subscriptions=List[RoomID])
subscriptions=List[Subscription])
Entry = NamedTuple("Entry", feed_id=int, id=str, date=datetime, title=str, summary=str, link=str)
@ -48,7 +51,8 @@ class Database:
Column("feed_id", Integer, ForeignKey("feed.id"),
primary_key=True),
Column("room_id", String(255), primary_key=True),
Column("user_id", String(255), nullable=False))
Column("user_id", String(255), nullable=False),
Column("notification_template", String(255), nullable=True))
self.entry = Table("entry", metadata,
Column("feed_id", Integer, ForeignKey("feed.id"), primary_key=True),
Column("id", String(255), primary_key=True),
@ -61,18 +65,24 @@ class Database:
metadata.create_all(db)
def get_feeds(self) -> Iterable[Feed]:
rows = self.db.execute(select([self.feed, self.subscription.c.room_id])
rows = self.db.execute(select([self.feed,
self.subscription.c.room_id,
self.subscription.c.user_id,
self.subscription.c.notification_template])
.where(self.subscription.c.feed_id == self.feed.c.id))
map: Dict[int, Feed] = {}
for row in rows:
feed_id, url, title, subtitle, link, room_id = row
feed_id, url, title, subtitle, link, room_id, user_id, notification_template = row
map.setdefault(feed_id, Feed(feed_id, url, title, subtitle, link, subscriptions=[]))
map[feed_id].subscriptions.append(room_id)
map[feed_id].subscriptions.append(
Subscription(feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template=Template(notification_template)))
return map.values()
def get_feeds_by_room(self, room_id: RoomID) -> Iterable[Feed]:
return (Feed(*row, subscriptions=[]) for row in
self.db.execute(select([self.feed])
def get_feeds_by_room(self, room_id: RoomID) -> Iterable[Tuple[Feed, UserID]]:
return ((Feed(feed_id, url, title, subtitle, link, subscriptions=[]), user_id)
for (feed_id, url, title, subtitle, link, user_id) in
self.db.execute(select([self.feed, self.subscription.c.user_id])
.where(and_(self.subscription.c.room_id == room_id,
self.subscription.c.feed_id == self.feed.c.id))))
@ -95,18 +105,33 @@ class Database:
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (StopIteration, IndexError):
except (ValueError, StopIteration):
return None
def get_feed_by_id_or_url(self, identifier: str) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(
or_(self.feed.c.url == identifier, self.feed.c.id == identifier)))
def get_feed_by_id(self, feed_id: int) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(self.feed.c.id == feed_id))
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (StopIteration, IndexError):
except (ValueError, StopIteration):
return None
def get_subscription(self, feed_id: int, room_id: RoomID) -> Tuple[Optional[Subscription],
Optional[Feed]]:
tbl = self.subscription
rows = self.db.execute(select([self.feed, tbl.c.room_id, tbl.c.user_id,
tbl.c.notification_template])
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id,
self.feed.c.id == feed_id)))
try:
feed_id, url, title, subtitle, link, room_id, user_id, template = next(rows)
notification_template = Template(template)
return (Subscription(feed_id, room_id, user_id, notification_template)
if room_id else None,
Feed(feed_id, url, title, subtitle, link, []))
except (ValueError, StopIteration):
return (None, None)
def create_feed(self, url: str, title: str, subtitle: str, link: str) -> Feed:
res = self.db.execute(self.feed.insert().values(url=url, title=title, subtitle=subtitle,
link=link))
@ -114,10 +139,17 @@ class Database:
link=link, subscriptions=[])
def subscribe(self, feed_id: int, room_id: RoomID, user_id: UserID) -> None:
self.db.execute(self.subscription.insert().values(feed_id=feed_id, room_id=room_id,
user_id=user_id))
self.db.execute(self.subscription.insert().values(
feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template="New post in $feed_title: [$title]($link)"))
def unsubscribe(self, feed_id: int, room_id: RoomID) -> None:
tbl = self.subscription
self.db.execute(tbl.delete().where(and_(tbl.c.feed_id == feed_id,
tbl.c.room_id == room_id)))
def update_template(self, feed_id: int, room_id: RoomID, template: str) -> None:
tbl = self.subscription
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(notification_template=template))