From 40ad7f0df362aeeed067083a98a35830bc17b6cd Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 28 Oct 2019 14:55:32 -0600 Subject: [PATCH] Update for new Spam Checker API --- synapse_antispam/mjolnir/antispam.py | 6 +++--- synapse_antispam/mjolnir/ban_list.py | 14 +++----------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/synapse_antispam/mjolnir/antispam.py b/synapse_antispam/mjolnir/antispam.py index 9fc46c4..ec0234c 100644 --- a/synapse_antispam/mjolnir/antispam.py +++ b/synapse_antispam/mjolnir/antispam.py @@ -21,12 +21,12 @@ from synapse.types import UserID logger = logging.getLogger("synapse.contrib." + __name__) class AntiSpam(object): - def __init__(self, config, hs): + def __init__(self, config, api): self.block_invites = config.get("block_invites", True) self.block_messages = config.get("block_messages", False) self.list_room_ids = config.get("ban_lists", []) self.rooms_to_lists = {} # type: Dict[str, BanList] - self.hs = hs + self.api = api # Now we build the ban lists so we can match them self.build_lists() @@ -41,7 +41,7 @@ class AntiSpam(object): def get_list_for_room(self, room_id): if room_id not in self.rooms_to_lists: - self.rooms_to_lists[room_id] = BanList(hs=self.hs, room_id=room_id) + self.rooms_to_lists[room_id] = BanList(api=self.api, room_id=room_id) return self.rooms_to_lists[room_id] def is_user_banned(self, user_id): diff --git a/synapse_antispam/mjolnir/ban_list.py b/synapse_antispam/mjolnir/ban_list.py index 540676a..58d317e 100644 --- a/synapse_antispam/mjolnir/ban_list.py +++ b/synapse_antispam/mjolnir/ban_list.py @@ -16,13 +16,12 @@ import logging from .list_rule import ListRule, ALL_RULE_TYPES, USER_RULE_TYPES, SERVER_RULE_TYPES, ROOM_RULE_TYPES from twisted.internet import defer -from synapse.storage.state import StateFilter logger = logging.getLogger("synapse.contrib." + __name__) class BanList(object): - def __init__(self, hs, room_id): - self.hs = hs + def __init__(self, api, room_id): + self.api = api self.room_id = room_id self.server_rules = [] self.user_rules = [] @@ -69,12 +68,5 @@ class BanList(object): elif event_type in SERVER_RULE_TYPES: self.server_rules.append(rule) - @defer.inlineCallbacks def get_relevant_state_events(self): - store = self.hs.get_datastore() - ev_filter = StateFilter.from_types([(t, None) for t in ALL_RULE_TYPES]) - state_ids = yield store.get_filtered_current_state_ids( - room_id=self.room_id, state_filter=ev_filter - ) - state = yield store.get_events(state_ids.values()) - return state.values() + return self.api.get_state_events_in_room(self.room_id, [(t, None) for t in ALL_RULE_TYPES])