Refactoring

This commit is contained in:
jfriedli 2020-04-23 10:39:35 -07:00
parent d14988fa3f
commit e1bac8b6a7
10 changed files with 376 additions and 310 deletions

254
main.py
View file

@ -1,23 +1,10 @@
import os import os
import hmac
import mimetypes as mtype
from uuid import uuid4
import jinja2 import jinja2
import base64
import io
import binascii
import zipfile
from cerberus import Validator from matweb import utils, rest_api, frontend
import utils from flask import Flask
import file_removal_scheduler from flask_restful import Api
from libmat2 import parser_factory
from flask import Flask, flash, request, redirect, url_for, render_template, send_from_directory, after_this_request
from flask_restful import Resource, Api, reqparse, abort
from werkzeug.utils import secure_filename
from werkzeug.datastructures import FileStorage
from flask_cors import CORS from flask_cors import CORS
from urllib.parse import urljoin
def create_app(test_config=None): def create_app(test_config=None):
@ -32,235 +19,32 @@ def create_app(test_config=None):
if test_config is not None: if test_config is not None:
app.config.update(test_config) app.config.update(test_config)
# Non JS Frontend
app.jinja_loader = jinja2.ChoiceLoader([ # type: ignore app.jinja_loader = jinja2.ChoiceLoader([ # type: ignore
jinja2.FileSystemLoader(app.config['CUSTOM_TEMPLATES_DIR']), jinja2.FileSystemLoader(app.config['CUSTOM_TEMPLATES_DIR']),
app.jinja_loader, app.jinja_loader,
]) ])
app.register_blueprint(frontend.routes)
# Restful API hookup
api = Api(app) api = Api(app)
CORS(app, resources={r"/api/*": {"origins": utils.get_allow_origin_header_value()}}) CORS(app, resources={r"/api/*": {"origins": utils.get_allow_origin_header_value()}})
api.add_resource(
@app.route('/info') rest_api.APIUpload,
def info(): '/api/upload',
get_supported_extensions() resource_class_kwargs={'upload_folder': app.config['UPLOAD_FOLDER']}
return render_template(
'info.html', extensions=get_supported_extensions()
) )
api.add_resource(
@app.route('/download/<string:key>/<string:filename>') rest_api.APIDownload,
def download_file(key: str, filename: str): '/api/download/<string:key>/<string:filename>',
if filename != secure_filename(filename): resource_class_kwargs={'upload_folder': app.config['UPLOAD_FOLDER']}
return redirect(url_for('upload_file'))
complete_path, filepath = get_file_paths(filename)
file_removal_scheduler.run_file_removal_job(app.config['UPLOAD_FOLDER'])
if not os.path.exists(complete_path):
return redirect(url_for('upload_file'))
if hmac.compare_digest(utils.hash_file(complete_path), key) is False:
return redirect(url_for('upload_file'))
@after_this_request
def remove_file(response):
if os.path.exists(complete_path):
os.remove(complete_path)
return response
return send_from_directory(app.config['UPLOAD_FOLDER'], filepath, as_attachment=True)
@app.route('/', methods=['GET', 'POST'])
def upload_file():
utils.check_upload_folder(app.config['UPLOAD_FOLDER'])
mimetypes = get_supported_extensions()
if request.method == 'POST':
if 'file' not in request.files: # check if the post request has the file part
flash('No file part')
return redirect(request.url)
uploaded_file = request.files['file']
if not uploaded_file.filename:
flash('No selected file')
return redirect(request.url)
filename, filepath = save_file(uploaded_file)
parser, mime = get_file_parser(filepath)
if parser is None:
flash('The type %s is not supported' % mime)
return redirect(url_for('upload_file'))
meta = parser.get_meta()
if parser.remove_all() is not True:
flash('Unable to clean %s' % mime)
return redirect(url_for('upload_file'))
key, meta_after, output_filename = cleanup(parser, filepath)
return render_template(
'download.html', mimetypes=mimetypes, meta=meta, filename=output_filename, meta_after=meta_after, key=key
) )
api.add_resource(
max_file_size = int(app.config['MAX_CONTENT_LENGTH'] / 1024 / 1024) rest_api.APIBulkDownloadCreator,
return render_template('index.html', max_file_size=max_file_size, mimetypes=mimetypes) '/api/download/bulk',
resource_class_kwargs={'upload_folder': app.config['UPLOAD_FOLDER']}
def get_supported_extensions():
extensions = set()
for parser in parser_factory._get_parsers():
for m in parser.mimetypes:
extensions |= set(mtype.guess_all_extensions(m, strict=False))
# since `guess_extension` might return `None`, we need to filter it out
return sorted(filter(None, extensions))
def save_file(file):
filename = secure_filename(file.filename)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(os.path.join(filepath))
return filename, filepath
def get_file_parser(filepath: str):
parser, mime = parser_factory.get_parser(filepath)
return parser, mime
def cleanup(parser, filepath):
output_filename = os.path.basename(parser.output_filename)
parser, _ = parser_factory.get_parser(parser.output_filename)
meta_after = parser.get_meta()
os.remove(filepath)
key = utils.hash_file(os.path.join(app.config['UPLOAD_FOLDER'], output_filename))
return key, meta_after, output_filename
def get_file_paths(filename):
filepath = secure_filename(filename)
complete_path = os.path.join(app.config['UPLOAD_FOLDER'], filepath)
return complete_path, filepath
def is_valid_api_download_file(filename, key):
if filename != secure_filename(filename):
abort(400, message='Insecure filename')
complete_path, filepath = get_file_paths(filename)
if not os.path.exists(complete_path):
abort(404, message='File not found')
if hmac.compare_digest(utils.hash_file(complete_path), key) is False:
abort(400, message='The file hash does not match')
return complete_path, filepath
class APIUpload(Resource):
def post(self):
utils.check_upload_folder(app.config['UPLOAD_FOLDER'])
req_parser = reqparse.RequestParser()
req_parser.add_argument('file_name', type=str, required=True, help='Post parameter is not specified: file_name')
req_parser.add_argument('file', type=str, required=True, help='Post parameter is not specified: file')
args = req_parser.parse_args()
try:
file_data = base64.b64decode(args['file'])
except binascii.Error as err:
abort(400, message='Failed decoding file: ' + str(err))
file = FileStorage(stream=io.BytesIO(file_data), filename=args['file_name'])
filename, filepath = save_file(file)
parser, mime = get_file_parser(filepath)
if parser is None:
abort(415, message='The type %s is not supported' % mime)
meta = parser.get_meta()
if not parser.remove_all():
abort(500, message='Unable to clean %s' % mime)
key, meta_after, output_filename = cleanup(parser, filepath)
return utils.return_file_created_response(
output_filename,
mime,
key,
meta,
meta_after,
urljoin(request.host_url, '%s/%s/%s/%s' % ('api', 'download', key, output_filename))
) )
api.add_resource(rest_api.APISupportedExtensions, '/api/extension')
class APIDownload(Resource):
def get(self, key: str, filename: str):
complete_path, filepath = is_valid_api_download_file(filename, key)
# Make sure the file is NOT deleted on HEAD requests
if request.method == 'GET':
file_removal_scheduler.run_file_removal_job(app.config['UPLOAD_FOLDER'])
@after_this_request
def remove_file(response):
if os.path.exists(complete_path):
os.remove(complete_path)
return response
return send_from_directory(app.config['UPLOAD_FOLDER'], filepath, as_attachment=True)
class APIBulkDownloadCreator(Resource):
schema = {
'download_list': {
'type': 'list',
'minlength': 2,
'maxlength': int(os.environ.get('MAT2_MAX_FILES_BULK_DOWNLOAD', 10)),
'schema': {
'type': 'dict',
'schema': {
'key': {'type': 'string', 'required': True},
'file_name': {'type': 'string', 'required': True}
}
}
}
}
v = Validator(schema)
def post(self):
utils.check_upload_folder(app.config['UPLOAD_FOLDER'])
data = request.json
if not self.v.validate(data):
abort(400, message=self.v.errors)
# prevent the zip file from being overwritten
zip_filename = 'files.' + str(uuid4()) + '.zip'
zip_path = os.path.join(app.config['UPLOAD_FOLDER'], zip_filename)
cleaned_files_zip = zipfile.ZipFile(zip_path, 'w')
with cleaned_files_zip:
for file_candidate in data['download_list']:
complete_path, file_path = is_valid_api_download_file(
file_candidate['file_name'],
file_candidate['key']
)
try:
cleaned_files_zip.write(complete_path)
os.remove(complete_path)
except ValueError:
abort(400, message='Creating the archive failed')
try:
cleaned_files_zip.testzip()
except ValueError as e:
abort(400, message=str(e))
parser, mime = get_file_parser(zip_path)
if not parser.remove_all():
abort(500, message='Unable to clean %s' % mime)
key, meta_after, output_filename = cleanup(parser, zip_path)
return {
'output_filename': output_filename,
'mime': mime,
'key': key,
'meta_after': meta_after,
'download_link': urljoin(request.host_url, '%s/%s/%s/%s' % ('api', 'download', key, output_filename))
}, 201
class APISupportedExtensions(Resource):
def get(self):
return get_supported_extensions()
api.add_resource(APIUpload, '/api/upload')
api.add_resource(APIDownload, '/api/download/<string:key>/<string:filename>')
api.add_resource(APIBulkDownloadCreator, '/api/download/bulk')
api.add_resource(APISupportedExtensions, '/api/extension')
return app return app

77
matweb/frontend.py Normal file
View file

@ -0,0 +1,77 @@
import hmac
import os
from flask import Blueprint, render_template, url_for, current_app, after_this_request, send_from_directory, request, \
flash
from werkzeug.utils import secure_filename, redirect
from matweb import file_removal_scheduler, utils
routes = Blueprint('routes', __name__)
@routes.route('/info')
def info():
utils.get_supported_extensions()
return render_template(
'info.html', extensions=utils.get_supported_extensions()
)
@routes.route('/download/<string:key>/<string:filename>')
def download_file(key: str, filename: str):
if filename != secure_filename(filename):
return redirect(url_for('routes.upload_file'))
complete_path, filepath = utils.get_file_paths(filename, current_app.config['UPLOAD_FOLDER'])
file_removal_scheduler.run_file_removal_job(current_app.config['UPLOAD_FOLDER'])
if not os.path.exists(complete_path):
return redirect(url_for('routes.upload_file'))
if hmac.compare_digest(utils.hash_file(complete_path), key) is False:
return redirect(url_for('routes.upload_file'))
@after_this_request
def remove_file(response):
if os.path.exists(complete_path):
os.remove(complete_path)
return response
return send_from_directory(current_app.config['UPLOAD_FOLDER'], filepath, as_attachment=True)
@routes.route('/', methods=['GET', 'POST'])
def upload_file():
utils.check_upload_folder(current_app.config['UPLOAD_FOLDER'])
mime_types = utils.get_supported_extensions()
if request.method == 'POST':
if 'file' not in request.files: # check if the post request has the file part
flash('No file part')
return redirect(request.url)
uploaded_file = request.files['file']
if not uploaded_file.filename:
flash('No selected file')
return redirect(request.url)
filename, filepath = utils.save_file(uploaded_file, current_app.config['UPLOAD_FOLDER'])
parser, mime = utils.get_file_parser(filepath)
if parser is None:
flash('The type %s is not supported' % mime)
return redirect(url_for('routes.upload_file'))
meta = parser.get_meta()
if parser.remove_all() is not True:
flash('Unable to clean %s' % mime)
return redirect(url_for('routes.upload_file'))
key, meta_after, output_filename = utils.cleanup(parser, filepath, current_app.config['UPLOAD_FOLDER'])
return render_template(
'download.html', mimetypes=mime_types, meta=meta, filename=output_filename, meta_after=meta_after, key=key
)
max_file_size = int(current_app.config['MAX_CONTENT_LENGTH'] / 1024 / 1024)
return render_template('index.html', max_file_size=max_file_size, mimetypes=mime_types)

139
matweb/rest_api.py Normal file
View file

@ -0,0 +1,139 @@
import os
import base64
import io
import binascii
import zipfile
from uuid import uuid4
from flask import after_this_request, send_from_directory
from flask_restful import Resource, reqparse, abort, request
from cerberus import Validator
from werkzeug.datastructures import FileStorage
from urllib.parse import urljoin
from matweb import file_removal_scheduler, utils
class APIUpload(Resource):
def __init__(self, **kwargs):
self.upload_folder = kwargs['upload_folder']
def post(self):
utils.check_upload_folder(self.upload_folder)
req_parser = reqparse.RequestParser()
req_parser.add_argument('file_name', type=str, required=True, help='Post parameter is not specified: file_name')
req_parser.add_argument('file', type=str, required=True, help='Post parameter is not specified: file')
args = req_parser.parse_args()
try:
file_data = base64.b64decode(args['file'])
except binascii.Error as err:
abort(400, message='Failed decoding file: ' + str(err))
file = FileStorage(stream=io.BytesIO(file_data), filename=args['file_name'])
filename, filepath = utils.save_file(file, self.upload_folder)
parser, mime = utils.get_file_parser(filepath)
if parser is None:
abort(415, message='The type %s is not supported' % mime)
meta = parser.get_meta()
if not parser.remove_all():
abort(500, message='Unable to clean %s' % mime)
key, meta_after, output_filename = utils.cleanup(parser, filepath, self.upload_folder)
return utils.return_file_created_response(
output_filename,
mime,
key,
meta,
meta_after,
urljoin(request.host_url, '%s/%s/%s/%s' % ('api', 'download', key, output_filename))
)
class APIDownload(Resource):
def __init__(self, **kwargs):
self.upload_folder = kwargs['upload_folder']
def get(self, key: str, filename: str):
complete_path, filepath = utils.is_valid_api_download_file(filename, key, self.upload_folder)
# Make sure the file is NOT deleted on HEAD requests
if request.method == 'GET':
file_removal_scheduler.run_file_removal_job(self.upload_folder)
@after_this_request
def remove_file(response):
if os.path.exists(complete_path):
os.remove(complete_path)
return response
return send_from_directory(self.upload_folder, filepath, as_attachment=True)
class APIBulkDownloadCreator(Resource):
def __init__(self, **kwargs):
self.upload_folder = kwargs['upload_folder']
schema = {
'download_list': {
'type': 'list',
'minlength': 2,
'maxlength': int(os.environ.get('MAT2_MAX_FILES_BULK_DOWNLOAD', 10)),
'schema': {
'type': 'dict',
'schema': {
'key': {'type': 'string', 'required': True},
'file_name': {'type': 'string', 'required': True}
}
}
}
}
v = Validator(schema)
def post(self):
utils.check_upload_folder(self.upload_folder)
data = request.json
if not self.v.validate(data):
abort(400, message=self.v.errors)
# prevent the zip file from being overwritten
zip_filename = 'files.' + str(uuid4()) + '.zip'
zip_path = os.path.join(self.upload_folder, zip_filename)
cleaned_files_zip = zipfile.ZipFile(zip_path, 'w')
with cleaned_files_zip:
for file_candidate in data['download_list']:
complete_path, file_path = utils.is_valid_api_download_file(
file_candidate['file_name'],
file_candidate['key'],
self.upload_folder
)
try:
cleaned_files_zip.write(complete_path)
os.remove(complete_path)
except ValueError:
abort(400, message='Creating the archive failed')
try:
cleaned_files_zip.testzip()
except ValueError as e:
abort(400, message=str(e))
parser, mime = utils.get_file_parser(zip_path)
if not parser.remove_all():
abort(500, message='Unable to clean %s' % mime)
key, meta_after, output_filename = utils.cleanup(parser, zip_path, self.upload_folder)
return {
'output_filename': output_filename,
'mime': mime,
'key': key,
'meta_after': meta_after,
'download_link': urljoin(request.host_url, '%s/%s/%s/%s' % ('api', 'download', key, output_filename))
}, 201
class APISupportedExtensions(Resource):
def get(self):
return utils.get_supported_extensions()

91
matweb/utils.py Normal file
View file

@ -0,0 +1,91 @@
import hmac
import os
import hashlib
import mimetypes as mtype
from flask_restful import abort
from libmat2 import parser_factory
from werkzeug.utils import secure_filename
def get_allow_origin_header_value():
return os.environ.get('MAT2_ALLOW_ORIGIN_WHITELIST', '*').split(" ")
def hash_file(filepath: str) -> str:
sha256 = hashlib.sha256()
with open(filepath, 'rb') as f:
while True:
data = f.read(65536) # read the file by chunk of 64k
if not data:
break
sha256.update(data)
return sha256.hexdigest()
def check_upload_folder(upload_folder):
if not os.path.exists(upload_folder):
os.mkdir(upload_folder)
def return_file_created_response(output_filename, mime, key, meta, meta_after, download_link):
return {
'output_filename': output_filename,
'mime': mime,
'key': key,
'meta': meta,
'meta_after': meta_after,
'download_link': download_link
}
def get_supported_extensions():
extensions = set()
for parser in parser_factory._get_parsers():
for m in parser.mimetypes:
extensions |= set(mtype.guess_all_extensions(m, strict=False))
# since `guess_extension` might return `None`, we need to filter it out
return sorted(filter(None, extensions))
def save_file(file, upload_folder):
filename = secure_filename(file.filename)
filepath = os.path.join(upload_folder, filename)
file.save(os.path.join(filepath))
return filename, filepath
def get_file_parser(filepath: str):
parser, mime = parser_factory.get_parser(filepath)
return parser, mime
def cleanup(parser, filepath, upload_folder):
output_filename = os.path.basename(parser.output_filename)
parser, _ = parser_factory.get_parser(parser.output_filename)
meta_after = parser.get_meta()
os.remove(filepath)
key = hash_file(os.path.join(upload_folder, output_filename))
return key, meta_after, output_filename
def get_file_paths(filename, upload_folder):
filepath = secure_filename(filename)
complete_path = os.path.join(upload_folder, filepath)
return complete_path, filepath
def is_valid_api_download_file(filename, key, upload_folder):
if filename != secure_filename(filename):
abort(400, message='Insecure filename')
complete_path, filepath = get_file_paths(filename, upload_folder)
if not os.path.exists(complete_path):
abort(404, message='File not found')
if hmac.compare_digest(hash_file(complete_path), key) is False:
abort(400, message='The file hash does not match')
return complete_path, filepath

View file

@ -10,7 +10,7 @@
{% endif %} {% endif %}
<div class="uk-flex uk-flex-center"> <div class="uk-flex uk-flex-center">
<div> <div>
<a class="uk-flex-1" href='{{ url_for('download_file', key=key, filename=filename) }}'> <a class="uk-flex-1" href='{{ url_for('routes.download_file', key=key, filename=filename) }}'>
<button class="uk-button uk-button-primary"> <button class="uk-button uk-button-primary">
⇩ download cleaned file ⇩ download cleaned file
</button> </button>

View file

@ -13,11 +13,12 @@ import main
class Mat2WebTestCase(unittest.TestCase): class Mat2WebTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
os.environ.setdefault('MAT2_ALLOW_ORIGIN_WHITELIST', 'origin1.gnu origin2.gnu') os.environ.setdefault('MAT2_ALLOW_ORIGIN_WHITELIST', 'origin1.gnu origin2.gnu')
app = main.create_app()
self.upload_folder = tempfile.mkdtemp() self.upload_folder = tempfile.mkdtemp()
app.config.update( app = main.create_app(
TESTING=True, test_config={
UPLOAD_FOLDER=self.upload_folder 'TESTING': True,
'UPLOAD_FOLDER': self.upload_folder
}
) )
self.app = app.test_client() self.app = app.test_client()
@ -127,7 +128,7 @@ class Mat2WebTestCase(unittest.TestCase):
rv = self.app.get('/download/70623619c449a040968cdbea85945bf384fa30ed2d5d24fa3/test.cleaned.txt') rv = self.app.get('/download/70623619c449a040968cdbea85945bf384fa30ed2d5d24fa3/test.cleaned.txt')
self.assertEqual(rv.status_code, 302) self.assertEqual(rv.status_code, 302)
@patch('file_removal_scheduler.random.randint') @patch('matweb.file_removal_scheduler.random.randint')
def test_upload_leftover(self, randint_mock): def test_upload_leftover(self, randint_mock):
randint_mock.return_value = 0 randint_mock.return_value = 0
os.environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '0' os.environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '0'

View file

@ -14,12 +14,14 @@ import main
class Mat2APITestCase(unittest.TestCase): class Mat2APITestCase(unittest.TestCase):
def setUp(self): def setUp(self):
os.environ.setdefault('MAT2_ALLOW_ORIGIN_WHITELIST', 'origin1.gnu origin2.gnu') os.environ.setdefault('MAT2_ALLOW_ORIGIN_WHITELIST', 'origin1.gnu origin2.gnu')
app = main.create_app()
self.upload_folder = tempfile.mkdtemp() self.upload_folder = tempfile.mkdtemp()
app.config.update( app = main.create_app(
TESTING=True, test_config={
UPLOAD_FOLDER=self.upload_folder 'TESTING': True,
'UPLOAD_FOLDER': self.upload_folder
}
) )
self.app = app.test_client() self.app = app.test_client()
def tearDown(self): def tearDown(self):
@ -38,7 +40,7 @@ class Mat2APITestCase(unittest.TestCase):
self.assertEqual(request.headers['Access-Control-Allow-Origin'], 'origin1.gnu') self.assertEqual(request.headers['Access-Control-Allow-Origin'], 'origin1.gnu')
self.assertEqual(request.status_code, 200) self.assertEqual(request.status_code, 200)
data = json.loads(request.data.decode('utf-8')) data = request.get_json()
expected = { expected = {
'output_filename': 'test_name.cleaned.jpg', 'output_filename': 'test_name.cleaned.jpg',
'mime': 'image/jpeg', 'mime': 'image/jpeg',
@ -64,7 +66,7 @@ class Mat2APITestCase(unittest.TestCase):
self.assertEqual(request.headers['Content-Type'], 'application/json') self.assertEqual(request.headers['Content-Type'], 'application/json')
self.assertEqual(request.status_code, 400) self.assertEqual(request.status_code, 400)
error = json.loads(request.data.decode('utf-8'))['message'] error = request.get_json()['message']
self.assertEqual(error['file'], 'Post parameter is not specified: file') self.assertEqual(error['file'], 'Post parameter is not specified: file')
request = self.app.post('/api/upload', request = self.app.post('/api/upload',
@ -74,7 +76,7 @@ class Mat2APITestCase(unittest.TestCase):
self.assertEqual(request.headers['Content-Type'], 'application/json') self.assertEqual(request.headers['Content-Type'], 'application/json')
self.assertEqual(request.status_code, 400) self.assertEqual(request.status_code, 400)
error = json.loads(request.data.decode('utf-8'))['message'] error = request.get_json()['message']
self.assertEqual(error, 'Failed decoding file: Incorrect padding') self.assertEqual(error, 'Failed decoding file: Incorrect padding')
def test_api_not_supported(self): def test_api_not_supported(self):
@ -87,7 +89,7 @@ class Mat2APITestCase(unittest.TestCase):
self.assertEqual(request.headers['Content-Type'], 'application/json') self.assertEqual(request.headers['Content-Type'], 'application/json')
self.assertEqual(request.status_code, 415) self.assertEqual(request.status_code, 415)
error = json.loads(request.data.decode('utf-8'))['message'] error = request.get_json()['message']
self.assertEqual(error, 'The type application/pdf is not supported') self.assertEqual(error, 'The type application/pdf is not supported')
def test_api_supported_extensions(self): def test_api_supported_extensions(self):
@ -136,7 +138,7 @@ class Mat2APITestCase(unittest.TestCase):
'iaj111eAsAAQTpAwAABOkDAABQSwUGAAAAAAIAAgC8AAAAwAAAAAAA"}', 'iaj111eAsAAQTpAwAABOkDAABQSwUGAAAAAAIAAgC8AAAAwAAAAAAA"}',
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
error = json.loads(request.data.decode('utf-8'))['message'] error = request.get_json()['message']
self.assertEqual(error, 'Unable to clean application/zip') self.assertEqual(error, 'Unable to clean application/zip')
@ -148,25 +150,25 @@ class Mat2APITestCase(unittest.TestCase):
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
self.assertEqual(request.status_code, 200) self.assertEqual(request.status_code, 200)
data = json.loads(request.data.decode('utf-8')) data = request.get_json()
request = self.app.get('http://localhost/api/download/' request = self.app.get('http://localhost/api/download/'
'81a541f9ebc0233d419d25ed39908b16f82be26a783f32d56c381559e84e6161/test name.cleaned.jpg') '81a541f9ebc0233d419d25ed39908b16f82be26a783f32d56c381559e84e6161/test name.cleaned.jpg')
self.assertEqual(request.status_code, 400) self.assertEqual(request.status_code, 400)
error = json.loads(request.data.decode('utf-8'))['message'] error = request.get_json()['message']
self.assertEqual(error, 'Insecure filename') self.assertEqual(error, 'Insecure filename')
request = self.app.get('http://localhost/api/download/' request = self.app.get('http://localhost/api/download/'
'81a541f9ebc0233d419d25ed39908b16f82be26a783f32d56c381559e84e6161/' '81a541f9ebc0233d419d25ed39908b16f82be26a783f32d56c381559e84e6161/'
'wrong_file_name.jpg') 'wrong_file_name.jpg')
self.assertEqual(request.status_code, 404) self.assertEqual(request.status_code, 404)
error = json.loads(request.data.decode('utf-8'))['message'] error = request.get_json()['message']
self.assertEqual(error, 'File not found') self.assertEqual(error, 'File not found')
request = self.app.get('http://localhost/api/download/81a541f9e/test_name.cleaned.jpg') request = self.app.get('http://localhost/api/download/81a541f9e/test_name.cleaned.jpg')
self.assertEqual(request.status_code, 400) self.assertEqual(request.status_code, 400)
error = json.loads(request.data.decode('utf-8'))['message'] error = request.get_json()['message']
self.assertEqual(error, 'The file hash does not match') self.assertEqual(error, 'The file hash does not match')
request = self.app.head(data['download_link']) request = self.app.head(data['download_link'])
@ -188,7 +190,7 @@ class Mat2APITestCase(unittest.TestCase):
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
self.assertEqual(request.status_code, 200) self.assertEqual(request.status_code, 200)
upload_one = json.loads(request.data.decode('utf-8')) upload_one = request.get_json()
request = self.app.post('/api/upload', request = self.app.post('/api/upload',
data='{"file_name": "test_name_two.jpg", ' data='{"file_name": "test_name_two.jpg", '
@ -197,7 +199,7 @@ class Mat2APITestCase(unittest.TestCase):
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
self.assertEqual(request.status_code, 200) self.assertEqual(request.status_code, 200)
upload_two = json.loads(request.data.decode('utf-8')) upload_two = request.get_json()
post_body = { post_body = {
u'download_list': [ u'download_list': [
@ -216,7 +218,7 @@ class Mat2APITestCase(unittest.TestCase):
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
response = json.loads(request.data.decode('utf-8')) response = request.get_json()
self.assertEqual(request.status_code, 201) self.assertEqual(request.status_code, 201)
self.assertIn( self.assertIn(
@ -268,7 +270,7 @@ class Mat2APITestCase(unittest.TestCase):
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
response = json.loads(request.data.decode('utf-8')) response = request.get_json()
self.assertEqual(response['message']['download_list'][0], 'min length is 2') self.assertEqual(response['message']['download_list'][0], 'min length is 2')
self.assertEqual(request.status_code, 400) self.assertEqual(request.status_code, 400)
@ -280,7 +282,7 @@ class Mat2APITestCase(unittest.TestCase):
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
response = json.loads(request.data.decode('utf-8')) response = request.get_json()
self.assertEqual(response['message']['download_list'][0]['0'][0]['file_name'][0], 'required field') self.assertEqual(response['message']['download_list'][0]['0'][0]['file_name'][0], 'required field')
self.assertEqual(response['message']['download_list'][0]['0'][0]['key'][0], 'required field') self.assertEqual(response['message']['download_list'][0]['0'][0]['key'][0], 'required field')
self.assertEqual(request.status_code, 400) self.assertEqual(request.status_code, 400)
@ -338,7 +340,7 @@ class Mat2APITestCase(unittest.TestCase):
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
response = json.loads(request.data.decode('utf-8')) response = request.get_json()
self.assertEqual(response['message']['download_list'][0], 'max length is 10') self.assertEqual(response['message']['download_list'][0], 'max length is 10')
self.assertEqual(request.status_code, 400) self.assertEqual(request.status_code, 400)
@ -358,17 +360,18 @@ class Mat2APITestCase(unittest.TestCase):
data=json.dumps(post_body), data=json.dumps(post_body),
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
response = json.loads(request.data.decode('utf-8')) response = request.get_json()
self.assertEqual('File not found', response['message']) self.assertEqual('File not found', response['message'])
@patch('file_removal_scheduler.random.randint') @patch('matweb.file_removal_scheduler.random.randint')
def test_api_upload_leftover(self, randint_mock): def test_api_upload_leftover(self, randint_mock):
os.environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '0' os.environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '0'
app = main.create_app()
self.upload_folder = tempfile.mkdtemp() self.upload_folder = tempfile.mkdtemp()
app.config.update( app = main.create_app(
TESTING=True, test_config={
UPLOAD_FOLDER=self.upload_folder 'TESTING': True,
'UPLOAD_FOLDER': self.upload_folder
}
) )
app = app.test_client() app = app.test_client()
randint_mock.return_value = 1 randint_mock.return_value = 1
@ -385,7 +388,7 @@ class Mat2APITestCase(unittest.TestCase):
'FcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="}', 'FcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="}',
headers={'content-type': 'application/json'} headers={'content-type': 'application/json'}
) )
download_link = json.loads(request.data.decode('utf-8'))['download_link'] download_link = request.get_json()['download_link']
request = app.get(download_link) request = app.get(download_link)
self.assertEqual(code, request.status_code) self.assertEqual(code, request.status_code)

View file

@ -3,7 +3,9 @@ import tempfile
from os import path, environ from os import path, environ
import shutil import shutil
import file_removal_scheduler from unittest.mock import patch
from matweb import file_removal_scheduler
import main import main
@ -17,12 +19,13 @@ class Mat2WebTestCase(unittest.TestCase):
) )
self.app = app self.app = app
def test_removal(self): @patch('matweb.file_removal_scheduler.random.randint')
def test_removal(self, randint_mock):
filename = 'test_name.cleaned.jpg' filename = 'test_name.cleaned.jpg'
environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '0' environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '0'
open(path.join(self.upload_folder, filename), 'a').close() open(path.join(self.upload_folder, filename), 'a').close()
self.assertTrue(path.exists(path.join(self.upload_folder, ))) self.assertTrue(path.exists(path.join(self.upload_folder, )))
for i in range(0, 11): randint_mock.return_value = 0
file_removal_scheduler.run_file_removal_job(self.app.config['UPLOAD_FOLDER']) file_removal_scheduler.run_file_removal_job(self.app.config['UPLOAD_FOLDER'])
self.assertFalse(path.exists(path.join(self.upload_folder, filename))) self.assertFalse(path.exists(path.join(self.upload_folder, filename)))
@ -30,12 +33,13 @@ class Mat2WebTestCase(unittest.TestCase):
file_removal_scheduler.run_file_removal_job(self.app.config['UPLOAD_FOLDER']) file_removal_scheduler.run_file_removal_job(self.app.config['UPLOAD_FOLDER'])
self.assertTrue(path.exists(path.join(self.upload_folder, ))) self.assertTrue(path.exists(path.join(self.upload_folder, )))
def test_non_removal(self): @patch('matweb.file_removal_scheduler.random.randint')
def test_non_removal(self, randint_mock):
filename = u'i_should_no_be_removed.txt' filename = u'i_should_no_be_removed.txt'
environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '9999999' environ['MAT2_MAX_FILE_AGE_FOR_REMOVAL'] = '9999999'
open(path.join(self.upload_folder, filename), 'a').close() open(path.join(self.upload_folder, filename), 'a').close()
self.assertTrue(path.exists(path.join(self.upload_folder, filename))) self.assertTrue(path.exists(path.join(self.upload_folder, filename)))
for i in range(0, 11): randint_mock.return_value = 0
file_removal_scheduler.run_file_removal_job(self.app.config['UPLOAD_FOLDER']) file_removal_scheduler.run_file_removal_job(self.app.config['UPLOAD_FOLDER'])
self.assertTrue(path.exists(path.join(self.upload_folder, filename))) self.assertTrue(path.exists(path.join(self.upload_folder, filename)))

View file

@ -1,33 +0,0 @@
import os
import hashlib
def get_allow_origin_header_value():
return os.environ.get('MAT2_ALLOW_ORIGIN_WHITELIST', '*').split(" ")
def hash_file(filepath: str) -> str:
sha256 = hashlib.sha256()
with open(filepath, 'rb') as f:
while True:
data = f.read(65536) # read the file by chunk of 64k
if not data:
break
sha256.update(data)
return sha256.hexdigest()
def check_upload_folder(upload_folder):
if not os.path.exists(upload_folder):
os.mkdir(upload_folder)
def return_file_created_response(output_filename, mime, key, meta, meta_after, download_link):
return {
'output_filename': output_filename,
'mime': mime,
'key': key,
'meta': meta,
'meta_after': meta_after,
'download_link': download_link
}