From b3d4482c3c700a1a47bbcf5f9b2e329d5fd1c066 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 29 Nov 2018 00:40:38 +0200 Subject: [PATCH] Fix bugs and make command prefix configurable --- base-config.yaml | 2 + rss/bot.py | 185 ++++++++++++++++++++++++++++++----------------- rss/db.py | 2 +- 3 files changed, 123 insertions(+), 66 deletions(-) diff --git a/base-config.yaml b/base-config.yaml index 8b1d367..bc2a105 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -3,3 +3,5 @@ update_interval: 60 # The time to sleep between send requests when broadcasting a new feed entry. # Set to 0 to disable sleep or -1 to run all requests asynchronously at once. spam_sleep: 2 +# The prefix for all commands +command_prefix: "!rss" diff --git a/rss/bot.py b/rss/bot.py index a9b8bdf..5041a4c 100644 --- a/rss/bot.py +++ b/rss/bot.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Type, List, Any, Dict, Tuple, Awaitable +from typing import Type, List, Any, Dict, Tuple, Awaitable, Callable from datetime import datetime from time import mktime, time from string import Template @@ -33,6 +33,18 @@ class Config(BaseProxyConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: helper.copy("update_interval") helper.copy("spam_sleep") + helper.copy("command_prefix") + + +CommandHandler = Callable[[MessageEvent, str, List[str]], Awaitable[None]] + + +def command_handler(*args: str) -> Callable[[CommandHandler], CommandHandler]: + def wrapper(func: CommandHandler) -> CommandHandler: + setattr(func, "commands", args) + return func + + return wrapper class RSSBot(Plugin): @@ -40,11 +52,17 @@ class RSSBot(Plugin): poll_task: asyncio.Future http: aiohttp.ClientSession power_level_cache: Dict[RoomID, Tuple[int, PowerLevelStateEventContent]] + cmd_prefix: str + commands: Dict[str, CommandHandler] @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: return Config + def on_external_config_update(self) -> None: + self.config.load_and_update() + self.cmd_prefix = self.config["command_prefix"] + async def start(self) -> None: self.config.load_and_update() self.db = Database(self.request_db_engine()) @@ -52,6 +70,14 @@ class RSSBot(Plugin): self.http = self.client.api.session self.power_level_cache = {} self.poll_task = asyncio.ensure_future(self.poll_feeds(), loop=self.loop) + self.cmd_prefix = self.config["command_prefix"] + + self.commands = {} + for attr_name in dir(self): + if attr_name.startswith("command_"): + handler = getattr(self, attr_name) + for alias in handler.commands: + self.commands[alias] = handler async def stop(self) -> None: self.client.remove_event_handler(self.event_handler, EventType.ROOM_MESSAGE) @@ -106,12 +132,17 @@ class RSSBot(Plugin): while True: try: await self._poll_once() + except asyncio.CancelledError: + self.log.debug("Polling stopped") except Exception: self.log.exception("Error while polling feeds") await asyncio.sleep(self.config["update_interval"] * 60, loop=self.loop) async def read_feed(self, url: str) -> str: - resp = await self.http.get(url) + try: + resp = await self.http.get(url) + except aiohttp.client_exceptions.ClientError: + return "" content = await resp.text() return content @@ -156,70 +187,94 @@ class RSSBot(Plugin): return False return True + @command_handler("subscribe", "sub", "s") + async def command_subscribe(self, evt: MessageEvent, cmd: str, args: List[str]) -> None: + if not await self.can_manage(evt): + return + elif len(args) == 0: + await evt.reply(f"**Usage:** `{self.cmd_prefix} {cmd} `") + return + url = " ".join(args) + feed = self.db.get_feed_by_url(url) + if not feed: + metadata = feedparser.parse(await self.read_feed(url)) + if metadata.bozo: + await evt.reply("That doesn't look like a valid feed.") + return + channel = metadata.get("channel", {}) + feed = self.db.create_feed(url, channel.get("title", url), + channel.get("description", ""), + channel.get("link", "")) + self.db.add_entries(self.find_entries(feed.id, metadata.entries)) + self.db.subscribe(feed.id, evt.room_id, evt.sender) + await evt.reply(f"Subscribed to feed ID {feed.id}: [{feed.title}]({feed.url})") + + @command_handler("unsubscribe", "unsub", "u") + async def command_unsubscribe(self, evt: MessageEvent, cmd: str, args: List[str]) -> None: + if not await self.can_manage(evt): + return + try: + feed_id = int(args[0]) + except (ValueError, IndexError): + await evt.reply(f"**Usage:** `{self.cmd_prefix} {cmd} `") + return + sub, feed = self.db.get_subscription(feed_id, evt.room_id) + if not sub: + await evt.reply("This room is not subscribed to that feed") + return + self.db.unsubscribe(feed.id, evt.room_id) + await evt.reply(f"Unsubscribed from feed ID {feed.id}: [{feed.title}]({feed.url})") + + @command_handler("template", "tpl", "t") + async def command_template(self, evt: MessageEvent, cmd: str, args: List[str]) -> None: + if not await self.can_manage(evt): + return + try: + feed_id = int(args[0]) + template = " ".join(args[1:]) + except (ValueError, IndexError): + await evt.reply(f"**Usage:** `{self.cmd_prefix} {cmd} `") + return + sub, feed = self.db.get_subscription(feed_id, evt.room_id) + if not sub: + await evt.reply("This room is not subscribed to that feed") + return + self.db.update_template(feed.id, evt.room_id, template) + sample_entry = Entry(feed.id, "SAMPLE", datetime.now(), "Sample entry", + "This is a sample entry to demonstrate your new template", + "http://example.com") + await evt.reply(f"Template for feed ID {feed.id} updated. Sample notification:") + await self._send(feed, sample_entry, Template(template), sub.room_id) + + @command_handler("subscriptions", "subs", "list", "ls") + async def command_subscriptions(self, evt: MessageEvent, _1: str, _2: List[str]) -> None: + subscriptions = self.db.get_feeds_by_room(evt.room_id) + await evt.reply("**Subscriptions in this room:**\n\n" + + "\n".join(f"* {feed.id} - [{feed.title}]({feed.url}) (subscribed by " + f"[{subscriber}](https://matrix.to/#/{subscriber}))" + for feed, subscriber in subscriptions)) + + @command_handler("help") + async def command_help(self, evt: MessageEvent, cmd: str, _2: List[str]) -> None: + await evt.reply( + ("Unknown command. " if cmd != "help" and cmd != "" else "") + + "Available commands:\n\n" + f"* {self.cmd_prefix} **subscribe** __ - Subscribe to a feed\n" + f"* {self.cmd_prefix} **unsubscribe** __ - Unsubscribe from a feed\n" + f"* {self.cmd_prefix} **template** __ __ - Change the " + f"notification template for a feed\n" + f"* {self.cmd_prefix} **subscriptions** - List subscriptions in current room\n" + f"* {self.cmd_prefix} **help** - Print this message") + async def event_handler(self, evt: MessageEvent) -> None: - if evt.content.msgtype != MessageType.TEXT or not evt.content.body.startswith("!rss"): + if evt.content.msgtype != MessageType.TEXT or not evt.content.body.startswith( + self.cmd_prefix): return - args = evt.content.body[len("!rss "):].split(" ") + args = evt.content.body[len(self.cmd_prefix) + 1:].split(" ") cmd, args = args[0].lower(), args[1:] - if cmd == "sub" or cmd == "subscribe": - if not await self.can_manage(evt): - return - elif len(args) == 0: - await evt.reply(f"**Usage:** !rss {cmd} ") - return - url = " ".join(args) - feed = self.db.get_feed_by_url(url) - if not feed: - metadata = feedparser.parse(await self.read_feed(url)) - if metadata.bozo: - await evt.reply("That doesn't look like a valid feed.") - return - channel = metadata.get("channel", {}) - feed = self.db.create_feed(url, channel.get("title", url), - channel.get("description", ""), - channel.get("link", "")) - self.db.add_entries(self.find_entries(feed.id, metadata.entries)) - self.db.subscribe(feed.id, evt.room_id, evt.sender) - await evt.reply(f"Subscribed to feed ID {feed.id}: [{feed.title}]({feed.url})") - elif cmd == "unsub" or cmd == "unsubscribe": - if not await self.can_manage(evt): - return - try: - feed_id = int(args[0]) - except (ValueError, IndexError): - await evt.reply(f"**Usage:** !rss {cmd} ") - return - sub, feed = self.db.get_subscription(feed_id, evt.room_id) - if not sub: - await evt.reply("This room is not subscribed to that feed") - return - self.db.unsubscribe(feed.id, evt.room_id) - await evt.reply(f"Unsubscribed from feed ID {feed.id}: [{feed.title}]({feed.url})") - elif cmd == "template" or cmd == "tpl": - if not await self.can_manage(evt): - return - try: - feed_id = int(args[0]) - template = " ".join(args[1:]) - except (ValueError, IndexError): - await evt.reply(f"**Usage:** !rss {cmd} ") - return - sub, feed = self.db.get_subscription(feed_id, evt.room_id) - if not sub: - await evt.reply("This room is not subscribed to that feed") - return - self.db.update_template(feed.id, evt.room_id, template) - sample_entry = Entry(feed.id, "SAMPLE", datetime.now(), "Sample entry", - "This is a sample entry to demonstrate your new template", - "http://example.com") - await evt.reply(f"Template for feed ID {feed.id} updated. Sample notification:") - await self._send(feed, sample_entry, Template(template), sub.room_id) - elif cmd == "subs" or cmd == "subscriptions": - subscriptions = self.db.get_feeds_by_room(evt.room_id) - await evt.reply("**Subscriptions in this room:**\n\n" - + "\n".join(f"* {feed.id} - [{feed.title}]({feed.url}) (subscribed by " - f"[{subscriber}](https://matrix.to/#/{subscriber}))" - for feed, subscriber in subscriptions)) - else: - await evt.reply("**Usage:** !rss [params...]") + try: + handler = self.commands[cmd] + except KeyError: + handler = self.command_help + await handler(evt, cmd, args) diff --git a/rss/db.py b/rss/db.py index 89ee3c4..121f8ae 100644 --- a/rss/db.py +++ b/rss/db.py @@ -24,7 +24,7 @@ from sqlalchemy.engine.base import Engine from mautrix.types import UserID, RoomID -Subscription = NamedTuple("Subscription", feed_id=int, room_id=str, user_id=str, +Subscription = NamedTuple("Subscription", feed_id=int, room_id=RoomID, user_id=UserID, notification_template=Template) Feed = NamedTuple("Feed", id=int, url=str, title=str, subtitle=str, link=str, subscriptions=List[Subscription])