From a60a0c845fb2da704c591f00078654feb2c01d20 Mon Sep 17 00:00:00 2001 From: jfriedli Date: Mon, 23 Aug 2021 20:56:49 +0200 Subject: [PATCH] validate bulk body is parsable --- matweb/rest_api.py | 12 ++++++++---- test/test_api.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/matweb/rest_api.py b/matweb/rest_api.py index 48dcc10..5784bec 100644 --- a/matweb/rest_api.py +++ b/matweb/rest_api.py @@ -7,7 +7,7 @@ from uuid import uuid4 from flask import after_this_request, send_from_directory, Blueprint, current_app from flask_restful import Resource, reqparse, abort, request, url_for, Api -from cerberus import Validator +from cerberus import Validator, DocumentError from werkzeug.datastructures import FileStorage from flasgger import swag_from @@ -157,9 +157,13 @@ class APIBulkDownloadCreator(Resource): if not data: abort(400, message="Post Body Required") current_app.logger.error('BulkDownload - Missing Post Body') - if not self.v.validate(data): - current_app.logger.error('BulkDownload - Missing Post Body: %s', str(self.v.errors)) - abort(400, message=self.v.errors) + try: + if not self.v.validate(data): + current_app.logger.error('BulkDownload - Missing Post Body: %s', str(self.v.errors)) + abort(400, message=self.v.errors) + except DocumentError as e: + abort(400, message="Invalid Post Body") + current_app.logger.error('BulkDownload - Invalid Post Body: %s', str(e)) # prevent the zip file from being overwritten zip_filename = 'files.' + str(uuid4()) + '.zip' zip_path = os.path.join(current_app.config['UPLOAD_FOLDER'], zip_filename) diff --git a/test/test_api.py b/test/test_api.py index 878b0ab..427a1f1 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -413,6 +413,24 @@ class Mat2APITestCase(unittest.TestCase): request = app.get(download_link) self.assertEqual(code, request.status_code) + def test_download_naughty_input(self): + request = self.app.get( + '/api/download/%F2%8C%BF%BD%F1%AE%98%A3%E4%B7%B8%F2%9B%94%BE%F2%A7%8B%83%F1%B1%80%9F%F3%AA%89%A6/1p/str' + ) + error_message = request.get_json()['message'] + self.assertEqual(404, request.status_code) + self.assertEqual("File not found", error_message) + + def test_download_bulk_naughty_input(self): + request = self.app.post( + '/api/download/bulk', + data='\"\'\'\'&type %SYSTEMROOT%\\\\win.ini\"', + headers={'content-type': 'application/json'} + ) + error_message = request.get_json()['message'] + self.assertEqual(400, request.status_code) + self.assertEqual("Invalid Post Body", error_message) + def test_upload_naughty_input(self): request = self.app.post('/api/upload', data='{"file_name": "\\\\", '