Fix bugs and make command prefix configurable

This commit is contained in:
Tulir Asokan 2018-11-29 00:40:38 +02:00
parent 1ea4f71862
commit b3d4482c3c
3 changed files with 123 additions and 66 deletions

View File

@ -3,3 +3,5 @@ update_interval: 60
# The time to sleep between send requests when broadcasting a new feed entry.
# Set to 0 to disable sleep or -1 to run all requests asynchronously at once.
spam_sleep: 2
# The prefix for all commands
command_prefix: "!rss"

View File

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, List, Any, Dict, Tuple, Awaitable
from typing import Type, List, Any, Dict, Tuple, Awaitable, Callable
from datetime import datetime
from time import mktime, time
from string import Template
@ -33,6 +33,18 @@ class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("update_interval")
helper.copy("spam_sleep")
helper.copy("command_prefix")
CommandHandler = Callable[[MessageEvent, str, List[str]], Awaitable[None]]
def command_handler(*args: str) -> Callable[[CommandHandler], CommandHandler]:
def wrapper(func: CommandHandler) -> CommandHandler:
setattr(func, "commands", args)
return func
return wrapper
class RSSBot(Plugin):
@ -40,11 +52,17 @@ class RSSBot(Plugin):
poll_task: asyncio.Future
http: aiohttp.ClientSession
power_level_cache: Dict[RoomID, Tuple[int, PowerLevelStateEventContent]]
cmd_prefix: str
commands: Dict[str, CommandHandler]
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
return Config
def on_external_config_update(self) -> None:
self.config.load_and_update()
self.cmd_prefix = self.config["command_prefix"]
async def start(self) -> None:
self.config.load_and_update()
self.db = Database(self.request_db_engine())
@ -52,6 +70,14 @@ class RSSBot(Plugin):
self.http = self.client.api.session
self.power_level_cache = {}
self.poll_task = asyncio.ensure_future(self.poll_feeds(), loop=self.loop)
self.cmd_prefix = self.config["command_prefix"]
self.commands = {}
for attr_name in dir(self):
if attr_name.startswith("command_"):
handler = getattr(self, attr_name)
for alias in handler.commands:
self.commands[alias] = handler
async def stop(self) -> None:
self.client.remove_event_handler(self.event_handler, EventType.ROOM_MESSAGE)
@ -106,12 +132,17 @@ class RSSBot(Plugin):
while True:
try:
await self._poll_once()
except asyncio.CancelledError:
self.log.debug("Polling stopped")
except Exception:
self.log.exception("Error while polling feeds")
await asyncio.sleep(self.config["update_interval"] * 60, loop=self.loop)
async def read_feed(self, url: str) -> str:
resp = await self.http.get(url)
try:
resp = await self.http.get(url)
except aiohttp.client_exceptions.ClientError:
return ""
content = await resp.text()
return content
@ -156,70 +187,94 @@ class RSSBot(Plugin):
return False
return True
@command_handler("subscribe", "sub", "s")
async def command_subscribe(self, evt: MessageEvent, cmd: str, args: List[str]) -> None:
if not await self.can_manage(evt):
return
elif len(args) == 0:
await evt.reply(f"**Usage:** `{self.cmd_prefix} {cmd} <feed URL>`")
return
url = " ".join(args)
feed = self.db.get_feed_by_url(url)
if not feed:
metadata = feedparser.parse(await self.read_feed(url))
if metadata.bozo:
await evt.reply("That doesn't look like a valid feed.")
return
channel = metadata.get("channel", {})
feed = self.db.create_feed(url, channel.get("title", url),
channel.get("description", ""),
channel.get("link", ""))
self.db.add_entries(self.find_entries(feed.id, metadata.entries))
self.db.subscribe(feed.id, evt.room_id, evt.sender)
await evt.reply(f"Subscribed to feed ID {feed.id}: [{feed.title}]({feed.url})")
@command_handler("unsubscribe", "unsub", "u")
async def command_unsubscribe(self, evt: MessageEvent, cmd: str, args: List[str]) -> None:
if not await self.can_manage(evt):
return
try:
feed_id = int(args[0])
except (ValueError, IndexError):
await evt.reply(f"**Usage:** `{self.cmd_prefix} {cmd} <feed ID>`")
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.unsubscribe(feed.id, evt.room_id)
await evt.reply(f"Unsubscribed from feed ID {feed.id}: [{feed.title}]({feed.url})")
@command_handler("template", "tpl", "t")
async def command_template(self, evt: MessageEvent, cmd: str, args: List[str]) -> None:
if not await self.can_manage(evt):
return
try:
feed_id = int(args[0])
template = " ".join(args[1:])
except (ValueError, IndexError):
await evt.reply(f"**Usage:** `{self.cmd_prefix} {cmd} <feed ID> <new template>`")
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.update_template(feed.id, evt.room_id, template)
sample_entry = Entry(feed.id, "SAMPLE", datetime.now(), "Sample entry",
"This is a sample entry to demonstrate your new template",
"http://example.com")
await evt.reply(f"Template for feed ID {feed.id} updated. Sample notification:")
await self._send(feed, sample_entry, Template(template), sub.room_id)
@command_handler("subscriptions", "subs", "list", "ls")
async def command_subscriptions(self, evt: MessageEvent, _1: str, _2: List[str]) -> None:
subscriptions = self.db.get_feeds_by_room(evt.room_id)
await evt.reply("**Subscriptions in this room:**\n\n"
+ "\n".join(f"* {feed.id} - [{feed.title}]({feed.url}) (subscribed by "
f"[{subscriber}](https://matrix.to/#/{subscriber}))"
for feed, subscriber in subscriptions))
@command_handler("help")
async def command_help(self, evt: MessageEvent, cmd: str, _2: List[str]) -> None:
await evt.reply(
("Unknown command. " if cmd != "help" and cmd != "" else "") +
"Available commands:\n\n"
f"* {self.cmd_prefix} **subscribe** _<feed URL>_ - Subscribe to a feed\n"
f"* {self.cmd_prefix} **unsubscribe** _<feed ID>_ - Unsubscribe from a feed\n"
f"* {self.cmd_prefix} **template** _<feed ID>_ _<new template>_ - Change the "
f"notification template for a feed\n"
f"* {self.cmd_prefix} **subscriptions** - List subscriptions in current room\n"
f"* {self.cmd_prefix} **help** - Print this message")
async def event_handler(self, evt: MessageEvent) -> None:
if evt.content.msgtype != MessageType.TEXT or not evt.content.body.startswith("!rss"):
if evt.content.msgtype != MessageType.TEXT or not evt.content.body.startswith(
self.cmd_prefix):
return
args = evt.content.body[len("!rss "):].split(" ")
args = evt.content.body[len(self.cmd_prefix) + 1:].split(" ")
cmd, args = args[0].lower(), args[1:]
if cmd == "sub" or cmd == "subscribe":
if not await self.can_manage(evt):
return
elif len(args) == 0:
await evt.reply(f"**Usage:** !rss {cmd} <feed URL>")
return
url = " ".join(args)
feed = self.db.get_feed_by_url(url)
if not feed:
metadata = feedparser.parse(await self.read_feed(url))
if metadata.bozo:
await evt.reply("That doesn't look like a valid feed.")
return
channel = metadata.get("channel", {})
feed = self.db.create_feed(url, channel.get("title", url),
channel.get("description", ""),
channel.get("link", ""))
self.db.add_entries(self.find_entries(feed.id, metadata.entries))
self.db.subscribe(feed.id, evt.room_id, evt.sender)
await evt.reply(f"Subscribed to feed ID {feed.id}: [{feed.title}]({feed.url})")
elif cmd == "unsub" or cmd == "unsubscribe":
if not await self.can_manage(evt):
return
try:
feed_id = int(args[0])
except (ValueError, IndexError):
await evt.reply(f"**Usage:** !rss {cmd} <feed ID>")
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.unsubscribe(feed.id, evt.room_id)
await evt.reply(f"Unsubscribed from feed ID {feed.id}: [{feed.title}]({feed.url})")
elif cmd == "template" or cmd == "tpl":
if not await self.can_manage(evt):
return
try:
feed_id = int(args[0])
template = " ".join(args[1:])
except (ValueError, IndexError):
await evt.reply(f"**Usage:** !rss {cmd} <feed ID> <new template>")
return
sub, feed = self.db.get_subscription(feed_id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to that feed")
return
self.db.update_template(feed.id, evt.room_id, template)
sample_entry = Entry(feed.id, "SAMPLE", datetime.now(), "Sample entry",
"This is a sample entry to demonstrate your new template",
"http://example.com")
await evt.reply(f"Template for feed ID {feed.id} updated. Sample notification:")
await self._send(feed, sample_entry, Template(template), sub.room_id)
elif cmd == "subs" or cmd == "subscriptions":
subscriptions = self.db.get_feeds_by_room(evt.room_id)
await evt.reply("**Subscriptions in this room:**\n\n"
+ "\n".join(f"* {feed.id} - [{feed.title}]({feed.url}) (subscribed by "
f"[{subscriber}](https://matrix.to/#/{subscriber}))"
for feed, subscriber in subscriptions))
else:
await evt.reply("**Usage:** !rss <sub/unsub/subs> [params...]")
try:
handler = self.commands[cmd]
except KeyError:
handler = self.command_help
await handler(evt, cmd, args)

View File

@ -24,7 +24,7 @@ from sqlalchemy.engine.base import Engine
from mautrix.types import UserID, RoomID
Subscription = NamedTuple("Subscription", feed_id=int, room_id=str, user_id=str,
Subscription = NamedTuple("Subscription", feed_id=int, room_id=RoomID, user_id=UserID,
notification_template=Template)
Feed = NamedTuple("Feed", id=int, url=str, title=str, subtitle=str, link=str,
subscriptions=List[Subscription])