Update for new Spam Checker API

This commit is contained in:
Travis Ralston 2019-10-28 14:55:32 -06:00
parent d313fad561
commit 40ad7f0df3
2 changed files with 6 additions and 14 deletions

View File

@ -21,12 +21,12 @@ from synapse.types import UserID
logger = logging.getLogger("synapse.contrib." + __name__) logger = logging.getLogger("synapse.contrib." + __name__)
class AntiSpam(object): class AntiSpam(object):
def __init__(self, config, hs): def __init__(self, config, api):
self.block_invites = config.get("block_invites", True) self.block_invites = config.get("block_invites", True)
self.block_messages = config.get("block_messages", False) self.block_messages = config.get("block_messages", False)
self.list_room_ids = config.get("ban_lists", []) self.list_room_ids = config.get("ban_lists", [])
self.rooms_to_lists = {} # type: Dict[str, BanList] self.rooms_to_lists = {} # type: Dict[str, BanList]
self.hs = hs self.api = api
# Now we build the ban lists so we can match them # Now we build the ban lists so we can match them
self.build_lists() self.build_lists()
@ -41,7 +41,7 @@ class AntiSpam(object):
def get_list_for_room(self, room_id): def get_list_for_room(self, room_id):
if room_id not in self.rooms_to_lists: if room_id not in self.rooms_to_lists:
self.rooms_to_lists[room_id] = BanList(hs=self.hs, room_id=room_id) self.rooms_to_lists[room_id] = BanList(api=self.api, room_id=room_id)
return self.rooms_to_lists[room_id] return self.rooms_to_lists[room_id]
def is_user_banned(self, user_id): def is_user_banned(self, user_id):

View File

@ -16,13 +16,12 @@
import logging import logging
from .list_rule import ListRule, ALL_RULE_TYPES, USER_RULE_TYPES, SERVER_RULE_TYPES, ROOM_RULE_TYPES from .list_rule import ListRule, ALL_RULE_TYPES, USER_RULE_TYPES, SERVER_RULE_TYPES, ROOM_RULE_TYPES
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.state import StateFilter
logger = logging.getLogger("synapse.contrib." + __name__) logger = logging.getLogger("synapse.contrib." + __name__)
class BanList(object): class BanList(object):
def __init__(self, hs, room_id): def __init__(self, api, room_id):
self.hs = hs self.api = api
self.room_id = room_id self.room_id = room_id
self.server_rules = [] self.server_rules = []
self.user_rules = [] self.user_rules = []
@ -69,12 +68,5 @@ class BanList(object):
elif event_type in SERVER_RULE_TYPES: elif event_type in SERVER_RULE_TYPES:
self.server_rules.append(rule) self.server_rules.append(rule)
@defer.inlineCallbacks
def get_relevant_state_events(self): def get_relevant_state_events(self):
store = self.hs.get_datastore() return self.api.get_state_events_in_room(self.room_id, [(t, None) for t in ALL_RULE_TYPES])
ev_filter = StateFilter.from_types([(t, None) for t in ALL_RULE_TYPES])
state_ids = yield store.get_filtered_current_state_ids(
room_id=self.room_id, state_filter=ev_filter
)
state = yield store.get_events(state_ids.values())
return state.values()