Fix CLI tests

This commit is contained in:
Micah Lee 2021-12-01 21:01:32 -08:00
parent c8ba508d26
commit b3d53ca2f2
No known key found for this signature in database
GPG Key ID: 403C2657CD994F73
4 changed files with 46 additions and 32 deletions

View File

@ -541,7 +541,9 @@ class ZipWriter(object):
filename. filename.
""" """
def __init__(self, common, web, zip_filename=None, processed_size_callback=None): def __init__(
self, common, web=None, zip_filename=None, processed_size_callback=None
):
self.common = common self.common = common
self.web = web self.web = web
self.cancel_compression = False self.cancel_compression = False
@ -555,6 +557,7 @@ class ZipWriter(object):
self.zip_filename = f"{self.zip_temp_dir.name}/onionshare_{self.common.random_string(4, 6)}.zip" self.zip_filename = f"{self.zip_temp_dir.name}/onionshare_{self.common.random_string(4, 6)}.zip"
# Cleanup this temp dir # Cleanup this temp dir
if self.web:
self.web.cleanup_tempdirs.append(self.zip_temp_dir) self.web.cleanup_tempdirs.append(self.zip_temp_dir)
self.z = zipfile.ZipFile(self.zip_filename, "w", allowZip64=True) self.z = zipfile.ZipFile(self.zip_filename, "w", allowZip64=True)

View File

@ -37,7 +37,7 @@ def temp_dir():
"""Creates a persistent temporary directory for the CLI tests to use""" """Creates a persistent temporary directory for the CLI tests to use"""
global test_temp_dir global test_temp_dir
if not test_temp_dir: if not test_temp_dir:
test_temp_dir = tempfile.mkdtemp() test_temp_dir = tempfile.TemporaryDirectory()
return test_temp_dir return test_temp_dir
@ -47,10 +47,9 @@ def temp_dir_1024(temp_dir):
particular size (1024 bytes). particular size (1024 bytes).
""" """
new_temp_dir = tempfile.mkdtemp(dir=temp_dir) new_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir.name)
tmp_file, tmp_file_path = tempfile.mkstemp(dir=new_temp_dir) tmp_file = tempfile.NamedTemporaryFile(dir=new_temp_dir.name)
with open(tmp_file, "wb") as f: tmp_file.write(b"*" * 1024)
f.write(b"*" * 1024)
return new_temp_dir return new_temp_dir
@ -61,9 +60,8 @@ def temp_dir_1024_delete(temp_dir):
the file inside) will be deleted after fixture usage. the file inside) will be deleted after fixture usage.
""" """
with tempfile.TemporaryDirectory(dir=temp_dir) as new_temp_dir: with tempfile.TemporaryDirectory(dir=temp_dir.name) as new_temp_dir:
tmp_file, tmp_file_path = tempfile.mkstemp(dir=new_temp_dir) with open(os.path.join(new_temp_dir, "file"), "wb") as f:
with open(tmp_file, "wb") as f:
f.write(b"*" * 1024) f.write(b"*" * 1024)
yield new_temp_dir yield new_temp_dir
@ -72,9 +70,10 @@ def temp_dir_1024_delete(temp_dir):
def temp_file_1024(temp_dir): def temp_file_1024(temp_dir):
"""Create a temporary file of a particular size (1024 bytes).""" """Create a temporary file of a particular size (1024 bytes)."""
with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as tmp_file: filename = os.path.join(temp_dir.name, "file")
tmp_file.write(b"*" * 1024) with open(filename, "wb") as f:
return tmp_file.name f.write(b"*" * 1024)
return filename
@pytest.fixture @pytest.fixture
@ -84,11 +83,11 @@ def temp_file_1024_delete(temp_dir):
The temporary file will be deleted after fixture usage. The temporary file will be deleted after fixture usage.
""" """
with tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) as tmp_file: with tempfile.NamedTemporaryFile(dir=temp_dir.name, delete=False) as tmp_file:
tmp_file.write(b"*" * 1024) tmp_file.write(b"*" * 1024)
tmp_file.flush() tmp_file.flush()
tmp_file.close() tmp_file.close()
yield tmp_file.name yield tmp_file
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -54,7 +54,7 @@ class TestSettings:
"socks_port": 9999, "socks_port": 9999,
"use_stealth": True, "use_stealth": True,
} }
tmp_file, tmp_file_path = tempfile.mkstemp(dir=temp_dir) tmp_file, tmp_file_path = tempfile.mkstemp(dir=temp_dir.name)
with open(tmp_file, "w") as f: with open(tmp_file, "w") as f:
json.dump(custom_settings, f) json.dump(custom_settings, f)
settings_obj.filename = tmp_file_path settings_obj.filename = tmp_file_path
@ -69,7 +69,7 @@ class TestSettings:
def test_save(self, monkeypatch, temp_dir, settings_obj): def test_save(self, monkeypatch, temp_dir, settings_obj):
settings_filename = "default_settings.json" settings_filename = "default_settings.json"
new_temp_dir = tempfile.mkdtemp(dir=temp_dir) new_temp_dir = tempfile.mkdtemp(dir=temp_dir.name)
settings_path = os.path.join(new_temp_dir, settings_filename) settings_path = os.path.join(new_temp_dir, settings_filename)
settings_obj.filename = settings_path settings_obj.filename = settings_path
settings_obj.save() settings_obj.save()

View File

@ -50,7 +50,8 @@ def web_obj(temp_dir, common_obj, mode, num_files=0):
web = Web(common_obj, False, mode_settings, mode) web = Web(common_obj, False, mode_settings, mode)
web.running = True web.running = True
web.cleanup_filenames == [] web.cleanup_tempfiles == []
web.cleanup_tempdirs == []
web.app.testing = True web.app.testing = True
# Share mode # Share mode
@ -58,7 +59,9 @@ def web_obj(temp_dir, common_obj, mode, num_files=0):
# Add files # Add files
files = [] files = []
for _ in range(num_files): for _ in range(num_files):
with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as tmp_file: with tempfile.NamedTemporaryFile(
delete=False, dir=temp_dir.name
) as tmp_file:
tmp_file.write(b"*" * 1024) tmp_file.write(b"*" * 1024)
files.append(tmp_file.name) files.append(tmp_file.name)
web.share_mode.set_file_info(files) web.share_mode.set_file_info(files)
@ -131,7 +134,9 @@ class TestWeb:
with web.app.test_client() as c: with web.app.test_client() as c:
# Load / with valid auth # Load / with valid auth
res = c.get("/",) res = c.get(
"/",
)
res.get_data() res.get_data()
assert res.status_code == 200 assert res.status_code == 200
@ -169,7 +174,7 @@ class TestWeb:
def test_receive_mode_message_no_files(self, temp_dir, common_obj): def test_receive_mode_message_no_files(self, temp_dir, common_obj):
web = web_obj(temp_dir, common_obj, "receive") web = web_obj(temp_dir, common_obj, "receive")
data_dir = os.path.join(temp_dir, "OnionShare") data_dir = os.path.join(temp_dir.name, "OnionShare")
os.makedirs(data_dir, exist_ok=True) os.makedirs(data_dir, exist_ok=True)
web.settings.set("receive", "data_dir", data_dir) web.settings.set("receive", "data_dir", data_dir)
@ -200,7 +205,7 @@ class TestWeb:
def test_receive_mode_message_and_files(self, temp_dir, common_obj): def test_receive_mode_message_and_files(self, temp_dir, common_obj):
web = web_obj(temp_dir, common_obj, "receive") web = web_obj(temp_dir, common_obj, "receive")
data_dir = os.path.join(temp_dir, "OnionShare") data_dir = os.path.join(temp_dir.name, "OnionShare")
os.makedirs(data_dir, exist_ok=True) os.makedirs(data_dir, exist_ok=True)
web.settings.set("receive", "data_dir", data_dir) web.settings.set("receive", "data_dir", data_dir)
@ -235,7 +240,7 @@ class TestWeb:
def test_receive_mode_files_no_message(self, temp_dir, common_obj): def test_receive_mode_files_no_message(self, temp_dir, common_obj):
web = web_obj(temp_dir, common_obj, "receive") web = web_obj(temp_dir, common_obj, "receive")
data_dir = os.path.join(temp_dir, "OnionShare") data_dir = os.path.join(temp_dir.name, "OnionShare")
os.makedirs(data_dir, exist_ok=True) os.makedirs(data_dir, exist_ok=True)
web.settings.set("receive", "data_dir", data_dir) web.settings.set("receive", "data_dir", data_dir)
@ -267,7 +272,7 @@ class TestWeb:
def test_receive_mode_no_message_no_files(self, temp_dir, common_obj): def test_receive_mode_no_message_no_files(self, temp_dir, common_obj):
web = web_obj(temp_dir, common_obj, "receive") web = web_obj(temp_dir, common_obj, "receive")
data_dir = os.path.join(temp_dir, "OnionShare") data_dir = os.path.join(temp_dir.name, "OnionShare")
os.makedirs(data_dir, exist_ok=True) os.makedirs(data_dir, exist_ok=True)
web.settings.set("receive", "data_dir", data_dir) web.settings.set("receive", "data_dir", data_dir)
@ -300,15 +305,21 @@ class TestWeb:
res.get_data() res.get_data()
assert res.status_code == 200 assert res.status_code == 200
def test_cleanup(self, common_obj, temp_dir_1024, temp_file_1024): def test_cleanup(self, common_obj, temp_dir_1024):
web = web_obj(temp_dir_1024, common_obj, "share", 3) web = web_obj(temp_dir_1024, common_obj, "share", 3)
web.cleanup_filenames = [temp_dir_1024, temp_file_1024] temp_file = tempfile.NamedTemporaryFile()
temp_dir = tempfile.TemporaryDirectory()
web.cleanup_tempfiles = [temp_file]
web.cleanup_tempdirs = [temp_dir]
web.cleanup() web.cleanup()
assert os.path.exists(temp_file_1024) is False assert os.path.exists(temp_file.name) is False
assert os.path.exists(temp_dir_1024) is False assert os.path.exists(temp_dir.name) is False
assert web.cleanup_filenames == []
assert web.cleanup_tempfiles == []
assert web.cleanup_tempdirs == []
class TestZipWriterDefault: class TestZipWriterDefault:
@ -339,8 +350,10 @@ class TestZipWriterDefault:
assert default_zw.processed_size_callback(None) is None assert default_zw.processed_size_callback(None) is None
def test_add_file(self, default_zw, temp_file_1024_delete): def test_add_file(self, default_zw, temp_file_1024_delete):
default_zw.add_file(temp_file_1024_delete) default_zw.add_file(temp_file_1024_delete.name)
zipfile_info = default_zw.z.getinfo(os.path.basename(temp_file_1024_delete)) zipfile_info = default_zw.z.getinfo(
os.path.basename(temp_file_1024_delete.name)
)
assert zipfile_info.compress_type == zipfile.ZIP_DEFLATED assert zipfile_info.compress_type == zipfile.ZIP_DEFLATED
assert zipfile_info.file_size == 1024 assert zipfile_info.file_size == 1024
@ -568,7 +581,6 @@ class TestRangeRequests:
resp = client.get(url, headers=headers) resp = client.get(url, headers=headers)
assert resp.status_code == 206 assert resp.status_code == 206
@pytest.mark.skipif(sys.platform != "linux", reason="requires Linux") @pytest.mark.skipif(sys.platform != "linux", reason="requires Linux")
@check_unsupported("curl", ["--version"]) @check_unsupported("curl", ["--version"])
def test_curl(self, temp_dir, tmpdir, common_obj): def test_curl(self, temp_dir, tmpdir, common_obj):