Add check_media_file_for_spam spam checker hook

This commit is contained in:
Erik Johnston 2021-02-03 16:44:16 +00:00
parent afa18f1baa
commit 7e8083eb48
6 changed files with 210 additions and 6 deletions

1
changelog.d/9311.feature Normal file
View File

@ -0,0 +1 @@
Add hook to spam checker modules that allow checking file uploads and remote downloads.

View File

@ -61,6 +61,9 @@ class ExampleSpamChecker:
async def check_registration_for_spam(self, email_threepid, username, request_info): async def check_registration_for_spam(self, email_threepid, username, request_info):
return RegistrationBehaviour.ALLOW # allow all registrations return RegistrationBehaviour.ALLOW # allow all registrations
async def check_media_file_for_spam(self, file_wrapper, file_info):
return False # allow all media
``` ```
## Configuration ## Configuration

View File

@ -17,6 +17,8 @@
import inspect import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection from synapse.types import Collection
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
@ -214,3 +216,48 @@ class SpamChecker:
return behaviour return behaviour
return RegistrationBehaviour.ALLOW return RegistrationBehaviour.ALLOW
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> bool:
"""Checks if a piece of newly uploaded media should be blocked.
This will be called for local uploads, downloads of remote media, each
thumbnail generated for those, and web pages/images used for URL
previews.
Note that care should be taken to not do blocking IO operations in the
main thread. For example, to get the contents of a file a module
should do::
async def check_media_file_for_spam(
self, file: ReadableFileWrapper, file_info: FileInfo
) -> bool:
buffer = BytesIO()
await file.write_chunks_to(buffer.write)
if buffer.getvalue() == b"Hello World":
return True
return False
Args:
file: An object that allows reading the contents of the media.
file_info: Metadata about the file.
Returns:
True if the media should be blocked or False if it should be
allowed.
"""
for spam_checker in self.spam_checkers:
# For backwards compatibility, only run if the method exists on the
# spam checker
checker = getattr(spam_checker, "check_media_file_for_spam", None)
if checker:
spam = await maybe_awaitable(checker(file_wrapper, file_info))
if spam:
return True
return False

View File

@ -16,13 +16,17 @@ import contextlib
import logging import logging
import os import os
import shutil import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
import attr
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.api.errors import NotFoundError
from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder from ._base import FileInfo, Responder
@ -58,6 +62,8 @@ class MediaStorage:
self.local_media_directory = local_media_directory self.local_media_directory = local_media_directory
self.filepaths = filepaths self.filepaths = filepaths
self.storage_providers = storage_providers self.storage_providers = storage_providers
self.spam_checker = hs.get_spam_checker()
self.clock = hs.get_clock()
async def store_file(self, source: IO, file_info: FileInfo) -> str: async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other """Write `source` to the on disk media store, and also any other
@ -127,18 +133,29 @@ class MediaStorage:
f.flush() f.flush()
f.close() f.close()
spam = await self.spam_checker.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
if spam:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
raise SpamMediaException()
for provider in self.storage_providers: for provider in self.storage_providers:
await provider.store_file(path, file_info) await provider.store_file(path, file_info)
finished_called[0] = True finished_called[0] = True
yield f, fname, finish yield f, fname, finish
except Exception: except Exception as e:
try: try:
os.remove(fname) os.remove(fname)
except Exception: except Exception:
pass pass
raise
raise e from None
if not finished_called: if not finished_called:
raise Exception("Finished callback not called") raise Exception("Finished callback not called")
@ -302,3 +319,39 @@ class FileResponder(Responder):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close() self.open_file.close()
class SpamMediaException(NotFoundError):
"""The media was blocked by a spam checker, so we simply 404 the request (in
the same way as if it was quarantined).
"""
@attr.s(slots=True)
class ReadableFileWrapper:
"""Wrapper that allows reading a file in chunks, yielding to the reactor,
and writing to a callback.
This is simplified `FileSender` that takes an IO object rather than an
`IConsumer`.
"""
CHUNK_SIZE = 2 ** 14
clock = attr.ib(type=Clock)
path = attr.ib(type=str)
async def write_chunks_to(self, callback: Callable[[bytes], None]):
"""Reads the file in chunks and calls the callback with each chunk.
"""
with open(self.path, "rb") as file:
while True:
chunk = file.read(self.CHUNK_SIZE)
if not chunk:
break
callback(chunk)
# We yield to the reactor by sleeping for 0 seconds.
await self.clock.sleep(0)

View File

@ -22,6 +22,7 @@ from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer from synapse.app.homeserver import HomeServer
@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0] # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion # TODO(markjh): parse content-dispostion
try:
content_uri = await self.media_repo.create_content( content_uri = await self.media_repo.create_content(
media_type, upload_name, request.content, content_length, requester.user media_type, upload_name, request.content, content_length, requester.user
) )
except SpamMediaException:
# For uploading of media we want to respond with a 400, instead of
# the default 404, as that would just be confusing.
raise SynapseError(400, "Bad content")
logger.info("Uploaded content with URI %r", content_uri) logger.info("Uploaded content with URI %r", content_uri)

View File

@ -30,6 +30,8 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.media_storage import MediaStorage
@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request from tests.server import FakeSite, make_request
from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase): class MediaStorageTests(unittest.HomeserverTestCase):
@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"X-Robots-Tag"), headers.getRawHeaders(b"X-Robots-Tag"),
[b"noindex, nofollow, noarchive, noimageindex"], [b"noindex, nofollow, noarchive, noimageindex"],
) )
class TestSpamChecker:
"""A spam checker module that rejects all media that includes the bytes
`evil`.
"""
def __init__(self, config, api):
self.config = config
self.api = api
def parse_config(config):
return config
async def check_event_for_spam(self, foo):
return False # allow all events
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites
async def user_may_create_room(self, userid):
return True # allow all room creations
async def user_may_create_room_alias(self, userid, room_alias):
return True # allow all room aliases
async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms
async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
return b"evil" in buf.getvalue()
class SpamCheckerTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
admin.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
def default_config(self):
config = default_config("test")
config.update(
{
"spam_checker": [
{
"module": TestSpamChecker.__module__ + ".TestSpamChecker",
"config": {},
}
]
}
)
return config
def test_upload_innocent(self):
"""Attempt to upload some innocent data that should be allowed.
"""
image_data = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
self.helper.upload_media(
self.upload_resource, image_data, tok=self.tok, expect_code=200
)
def test_upload_ban(self):
"""Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker.
"""
data = b"Some evil data"
self.helper.upload_media(
self.upload_resource, data, tok=self.tok, expect_code=400
)