From c36fc9c20e3ec31359f8a77e1bdfb8ade3f0384c Mon Sep 17 00:00:00 2001 From: jfriedli Date: Wed, 2 Oct 2019 08:25:55 -0700 Subject: [PATCH] handle HEAD requests correctly --- main.py | 13 +++++++------ test/test_api.py | 8 ++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 5c6211d..a4766b1 100644 --- a/main.py +++ b/main.py @@ -35,7 +35,7 @@ def create_app(test_config=None): CORS(app, resources={r"/api/*": {"origins": utils.get_allow_origin_header_value()}}) @app.route('/download//') - def download_file(key:str, filename:str): + def download_file(key: str, filename:str): if filename != secure_filename(filename): return redirect(url_for('upload_file')) @@ -173,11 +173,12 @@ def create_app(test_config=None): class APIDownload(Resource): def get(self, key: str, filename: str): complete_path, filepath = is_valid_api_download_file(filename, key) - - @after_this_request - def remove_file(response): - os.remove(complete_path) - return response + # Make sure the file is NOT deleted on HEAD requests + if request.method == 'GET': + @after_this_request + def remove_file(response): + os.remove(complete_path) + return response return send_from_directory(app.config['UPLOAD_FOLDER'], filepath) diff --git a/test/test_api.py b/test/test_api.py index 532ceb9..2029820 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -151,6 +151,10 @@ class Mat2APITestCase(unittest.TestCase): error = json.loads(request.data.decode('utf-8'))['message'] self.assertEqual(error, 'The file hash does not match') + request = self.app.head(data['download_link']) + self.assertEqual(request.status_code, 200) + self.assertEqual(request.headers['Content-Length'], '633') + request = self.app.get(data['download_link']) self.assertEqual(request.status_code, 200) @@ -210,6 +214,10 @@ class Mat2APITestCase(unittest.TestCase): self.assertIn(response['mime'], 'application/zip') self.assertEqual(response['meta_after'], {}) + request = self.app.head(response['download_link']) + self.assertEqual(request.status_code, 200) + self.assertEqual(request.headers['Content-Length'], '1596') + request = self.app.get(response['download_link']) zip_response = zipfile.ZipFile(BytesIO(request.data)) self.assertEquals(2, len(zip_response.namelist()))