Make the spam checker a module

This commit is contained in:
David Baker 2017-09-26 19:20:23 +01:00
parent ccc67d445b
commit 6cd5fcd536
5 changed files with 33 additions and 23 deletions

View File

@ -34,6 +34,7 @@ from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig from .emailconfig import EmailConfig
from .workers import WorkerConfig from .workers import WorkerConfig
from .push import PushConfig from .push import PushConfig
from .spam_checker import SpamCheckerConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@ -41,7 +42,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig, JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,): WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig,):
pass pass

View File

@ -13,8 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class SpamChecker(object):
def __init__(self, hs):
self.spam_checker = None
def check_event_for_spam(event): if hs.config.spam_checker is not None:
module, config = hs.config.spam_checker
print("cfg %r", config)
self.spam_checker = module(config=config)
def check_event_for_spam(self, event):
"""Checks if a given event is considered "spammy" by this server. """Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if If the server considers an event spammy, then it will be rejected if
@ -27,12 +35,7 @@ def check_event_for_spam(event):
Returns: Returns:
bool: True if the event is spammy. bool: True if the event is spammy.
""" """
if not hasattr(event, "content") or "body" not in event.content: if self.spam_checker is None:
return False return False
# for example: return self.spam_checker.check_event_for_spam(event)
#
# if "the third flower is green" in event.content["body"]:
# return True
return False

View File

@ -16,7 +16,6 @@ import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import spamcheck
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer from twisted.internet import defer
@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
def __init__(self, hs): def __init__(self, hs):
pass self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@ -144,7 +143,7 @@ class FederationBase(object):
) )
return redacted return redacted
if spamcheck.check_event_for_spam(pdu): if self.spam_checker.check_event_for_spam(pdu):
logger.warn( logger.warn(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json() pdu.event_id, pdu.get_pdu_json()

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.events import spamcheck
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -58,6 +57,8 @@ class MessageHandler(BaseHandler):
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
@ -322,7 +323,7 @@ class MessageHandler(BaseHandler):
txn_id=txn_id txn_id=txn_id
) )
if spamcheck.check_event_for_spam(event): if self.spam_checker.check_event_for_spam(event):
raise SynapseError( raise SynapseError(
403, "Spam is not permitted here", Codes.FORBIDDEN 403, "Spam is not permitted here", Codes.FORBIDDEN
) )

View File

@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transport.client import TransportLayerClient
@ -139,6 +140,7 @@ class HomeServer(object):
'read_marker_handler', 'read_marker_handler',
'action_generator', 'action_generator',
'user_directory_handler', 'user_directory_handler',
'spam_checker',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -309,6 +311,9 @@ class HomeServer(object):
def build_user_directory_handler(self): def build_user_directory_handler(self):
return UserDirectoyHandler(self) return UserDirectoyHandler(self)
def build_spam_checker(self):
return SpamChecker(self)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)