handle HEAD requests correctly

This commit is contained in:
jfriedli 2019-10-02 08:25:55 -07:00
parent d9d4ebf3a2
commit c36fc9c20e
2 changed files with 15 additions and 6 deletions

View file

@ -173,7 +173,8 @@ def create_app(test_config=None):
class APIDownload(Resource): class APIDownload(Resource):
def get(self, key: str, filename: str): def get(self, key: str, filename: str):
complete_path, filepath = is_valid_api_download_file(filename, key) complete_path, filepath = is_valid_api_download_file(filename, key)
# Make sure the file is NOT deleted on HEAD requests
if request.method == 'GET':
@after_this_request @after_this_request
def remove_file(response): def remove_file(response):
os.remove(complete_path) os.remove(complete_path)

View file

@ -151,6 +151,10 @@ class Mat2APITestCase(unittest.TestCase):
error = json.loads(request.data.decode('utf-8'))['message'] error = json.loads(request.data.decode('utf-8'))['message']
self.assertEqual(error, 'The file hash does not match') self.assertEqual(error, 'The file hash does not match')
request = self.app.head(data['download_link'])
self.assertEqual(request.status_code, 200)
self.assertEqual(request.headers['Content-Length'], '633')
request = self.app.get(data['download_link']) request = self.app.get(data['download_link'])
self.assertEqual(request.status_code, 200) self.assertEqual(request.status_code, 200)
@ -210,6 +214,10 @@ class Mat2APITestCase(unittest.TestCase):
self.assertIn(response['mime'], 'application/zip') self.assertIn(response['mime'], 'application/zip')
self.assertEqual(response['meta_after'], {}) self.assertEqual(response['meta_after'], {})
request = self.app.head(response['download_link'])
self.assertEqual(request.status_code, 200)
self.assertEqual(request.headers['Content-Length'], '1596')
request = self.app.get(response['download_link']) request = self.app.get(response['download_link'])
zip_response = zipfile.ZipFile(BytesIO(request.data)) zip_response = zipfile.ZipFile(BytesIO(request.data))
self.assertEquals(2, len(zip_response.namelist())) self.assertEquals(2, len(zip_response.namelist()))