Add permission checks and add support for custom notification templates

This commit is contained in:
Tulir Asokan 2018-11-28 00:41:22 +02:00
parent 2ee880d9cb
commit 778be0aa19
2 changed files with 131 additions and 40 deletions

View file

@ -13,18 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable, NamedTuple, List, Optional, Dict
from typing import Iterable, NamedTuple, List, Optional, Dict, Tuple
from datetime import datetime
from string import Template
from sqlalchemy import (Column, String, Integer, DateTime, Text, ForeignKey,
Table, MetaData,
select, and_, or_)
select, and_)
from sqlalchemy.engine.base import Engine
from mautrix.types import UserID, RoomID
Subscription = NamedTuple("Subscription", feed_id=int, room_id=str, user_id=str,
notification_template=Template)
Feed = NamedTuple("Feed", id=int, url=str, title=str, subtitle=str, link=str,
subscriptions=List[RoomID])
subscriptions=List[Subscription])
Entry = NamedTuple("Entry", feed_id=int, id=str, date=datetime, title=str, summary=str, link=str)
@ -48,7 +51,8 @@ class Database:
Column("feed_id", Integer, ForeignKey("feed.id"),
primary_key=True),
Column("room_id", String(255), primary_key=True),
Column("user_id", String(255), nullable=False))
Column("user_id", String(255), nullable=False),
Column("notification_template", String(255), nullable=True))
self.entry = Table("entry", metadata,
Column("feed_id", Integer, ForeignKey("feed.id"), primary_key=True),
Column("id", String(255), primary_key=True),
@ -61,18 +65,24 @@ class Database:
metadata.create_all(db)
def get_feeds(self) -> Iterable[Feed]:
rows = self.db.execute(select([self.feed, self.subscription.c.room_id])
rows = self.db.execute(select([self.feed,
self.subscription.c.room_id,
self.subscription.c.user_id,
self.subscription.c.notification_template])
.where(self.subscription.c.feed_id == self.feed.c.id))
map: Dict[int, Feed] = {}
for row in rows:
feed_id, url, title, subtitle, link, room_id = row
feed_id, url, title, subtitle, link, room_id, user_id, notification_template = row
map.setdefault(feed_id, Feed(feed_id, url, title, subtitle, link, subscriptions=[]))
map[feed_id].subscriptions.append(room_id)
map[feed_id].subscriptions.append(
Subscription(feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template=Template(notification_template)))
return map.values()
def get_feeds_by_room(self, room_id: RoomID) -> Iterable[Feed]:
return (Feed(*row, subscriptions=[]) for row in
self.db.execute(select([self.feed])
def get_feeds_by_room(self, room_id: RoomID) -> Iterable[Tuple[Feed, UserID]]:
return ((Feed(feed_id, url, title, subtitle, link, subscriptions=[]), user_id)
for (feed_id, url, title, subtitle, link, user_id) in
self.db.execute(select([self.feed, self.subscription.c.user_id])
.where(and_(self.subscription.c.room_id == room_id,
self.subscription.c.feed_id == self.feed.c.id))))
@ -95,18 +105,33 @@ class Database:
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (StopIteration, IndexError):
except (ValueError, StopIteration):
return None
def get_feed_by_id_or_url(self, identifier: str) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(
or_(self.feed.c.url == identifier, self.feed.c.id == identifier)))
def get_feed_by_id(self, feed_id: int) -> Optional[Feed]:
rows = self.db.execute(select([self.feed]).where(self.feed.c.id == feed_id))
try:
row = next(rows)
return Feed(*row, subscriptions=[])
except (StopIteration, IndexError):
except (ValueError, StopIteration):
return None
def get_subscription(self, feed_id: int, room_id: RoomID) -> Tuple[Optional[Subscription],
Optional[Feed]]:
tbl = self.subscription
rows = self.db.execute(select([self.feed, tbl.c.room_id, tbl.c.user_id,
tbl.c.notification_template])
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id,
self.feed.c.id == feed_id)))
try:
feed_id, url, title, subtitle, link, room_id, user_id, template = next(rows)
notification_template = Template(template)
return (Subscription(feed_id, room_id, user_id, notification_template)
if room_id else None,
Feed(feed_id, url, title, subtitle, link, []))
except (ValueError, StopIteration):
return (None, None)
def create_feed(self, url: str, title: str, subtitle: str, link: str) -> Feed:
res = self.db.execute(self.feed.insert().values(url=url, title=title, subtitle=subtitle,
link=link))
@ -114,10 +139,17 @@ class Database:
link=link, subscriptions=[])
def subscribe(self, feed_id: int, room_id: RoomID, user_id: UserID) -> None:
self.db.execute(self.subscription.insert().values(feed_id=feed_id, room_id=room_id,
user_id=user_id))
self.db.execute(self.subscription.insert().values(
feed_id=feed_id, room_id=room_id, user_id=user_id,
notification_template="New post in $feed_title: [$title]($link)"))
def unsubscribe(self, feed_id: int, room_id: RoomID) -> None:
tbl = self.subscription
self.db.execute(tbl.delete().where(and_(tbl.c.feed_id == feed_id,
tbl.c.room_id == room_id)))
def update_template(self, feed_id: int, room_id: RoomID, template: str) -> None:
tbl = self.subscription
self.db.execute(tbl.update()
.where(and_(tbl.c.feed_id == feed_id, tbl.c.room_id == room_id))
.values(notification_template=template))