Fix tests after refactoring Common

This commit is contained in:
Micah Lee 2018-03-13 02:22:26 -07:00
parent 50409167d4
commit c2fecf8aa4
No known key found for this signature in database
GPG key ID: 403C2657CD994F73
7 changed files with 73 additions and 79 deletions

View file

@ -29,8 +29,6 @@ import zipfile
import pytest
from onionshare import common
LOG_MSG_REGEX = re.compile(r"""
^\[Jun\ 06\ 2013\ 11:05:00\]
\ TestModule\.<function\ TestLog\.test_output\.<locals>\.dummy_func
@ -38,6 +36,9 @@ LOG_MSG_REGEX = re.compile(r"""
SLUG_REGEX = re.compile(r'^([a-z]+)(-[a-z]+)?-([a-z]+)(-[a-z]+)?$')
# TODO: Improve the Common tests to test it all as a single class
class TestBuildSlug:
@pytest.mark.parametrize('test_input,expected', (
# VALID, two lowercase words, separated by a hyphen
@ -77,17 +78,17 @@ class TestBuildSlug:
assert bool(SLUG_REGEX.match(test_input)) == expected
def test_build_slug_unique(self, sys_onionshare_dev_mode):
assert common.build_slug() != common.build_slug()
def test_build_slug_unique(self, common_obj, sys_onionshare_dev_mode):
assert common_obj.build_slug() != common_obj.build_slug()
class TestDirSize:
def test_temp_dir_size(self, temp_dir_1024_delete):
def test_temp_dir_size(self, common_obj, temp_dir_1024_delete):
""" dir_size() should return the total size (in bytes) of all files
in a particular directory.
"""
assert common.dir_size(temp_dir_1024_delete) == 1024
assert common_obj.dir_size(temp_dir_1024_delete) == 1024
class TestEstimatedTimeRemaining:
@ -101,16 +102,16 @@ class TestEstimatedTimeRemaining:
((971, 1009, 83), '1s')
))
def test_estimated_time_remaining(
self, test_input, expected, time_time_100):
assert common.estimated_time_remaining(*test_input) == expected
self, common_obj, test_input, expected, time_time_100):
assert common_obj.estimated_time_remaining(*test_input) == expected
@pytest.mark.parametrize('test_input', (
(10, 20, 100), # if `time_elapsed == 0`
(0, 37, 99) # if `download_rate == 0`
))
def test_raises_zero_division_error(self, test_input, time_time_100):
def test_raises_zero_division_error(self, common_obj, test_input, time_time_100):
with pytest.raises(ZeroDivisionError):
common.estimated_time_remaining(*test_input)
common_obj.estimated_time_remaining(*test_input)
class TestFormatSeconds:
@ -129,16 +130,16 @@ class TestFormatSeconds:
(129674, '1d12h1m14s'),
(56404.12, '15h40m4s')
))
def test_format_seconds(self, test_input, expected):
assert common.format_seconds(test_input) == expected
def test_format_seconds(self, common_obj, test_input, expected):
assert common_obj.format_seconds(test_input) == expected
# TODO: test negative numbers?
@pytest.mark.parametrize('test_input', (
'string', lambda: None, [], {}, set()
))
def test_invalid_input_types(self, test_input):
def test_invalid_input_types(self, common_obj, test_input):
with pytest.raises(TypeError):
common.format_seconds(test_input)
common_obj.format_seconds(test_input)
class TestGetAvailablePort:
@ -146,29 +147,29 @@ class TestGetAvailablePort:
(random.randint(1024, 1500),
random.randint(1800, 2048)) for _ in range(50)
))
def test_returns_an_open_port(self, port_min, port_max):
def test_returns_an_open_port(self, common_obj, port_min, port_max):
""" get_available_port() should return an open port within the range """
port = common.get_available_port(port_min, port_max)
port = common_obj.get_available_port(port_min, port_max)
assert port_min <= port <= port_max
with socket.socket() as tmpsock:
tmpsock.bind(('127.0.0.1', port))
class TestGetPlatform:
def test_darwin(self, platform_darwin):
assert common.platform == 'Darwin'
def test_darwin(self, platform_darwin, common_obj):
assert common_obj.platform == 'Darwin'
def test_linux(self, platform_linux):
assert common.platform == 'Linux'
def test_linux(self, platform_linux, common_obj):
assert common_obj.platform == 'Linux'
def test_windows(self, platform_windows):
assert common.platform == 'Windows'
def test_windows(self, platform_windows, common_obj):
assert common_obj.platform == 'Windows'
# TODO: double-check these tests
class TestGetResourcePath:
def test_onionshare_dev_mode(self, sys_onionshare_dev_mode):
def test_onionshare_dev_mode(self, common_obj, sys_onionshare_dev_mode):
prefix = os.path.join(
os.path.dirname(
os.path.dirname(
@ -176,29 +177,29 @@ class TestGetResourcePath:
inspect.getfile(
inspect.currentframe())))), 'share')
assert (
common.get_resource_path(os.path.join(prefix, 'test_filename')) ==
common_obj.get_resource_path(os.path.join(prefix, 'test_filename')) ==
os.path.join(prefix, 'test_filename'))
def test_linux(self, platform_linux, sys_argv_sys_prefix):
def test_linux(self, common_obj, platform_linux, sys_argv_sys_prefix):
prefix = os.path.join(sys.prefix, 'share/onionshare')
assert (
common.get_resource_path(os.path.join(prefix, 'test_filename')) ==
common_obj.get_resource_path(os.path.join(prefix, 'test_filename')) ==
os.path.join(prefix, 'test_filename'))
def test_frozen_darwin(self, platform_darwin, sys_frozen, sys_meipass):
def test_frozen_darwin(self, common_obj, platform_darwin, sys_frozen, sys_meipass):
prefix = os.path.join(sys._MEIPASS, 'share')
assert (
common.get_resource_path(os.path.join(prefix, 'test_filename')) ==
common_obj.get_resource_path(os.path.join(prefix, 'test_filename')) ==
os.path.join(prefix, 'test_filename'))
class TestGetTorPaths:
# @pytest.mark.skipif(sys.platform != 'Darwin', reason='requires MacOS') ?
def test_get_tor_paths_darwin(self, platform_darwin, sys_frozen, sys_meipass):
def test_get_tor_paths_darwin(self, platform_darwin, common_obj, sys_frozen, sys_meipass):
base_path = os.path.dirname(
os.path.dirname(
os.path.dirname(
common.get_resource_path(''))))
common_obj.get_resource_path(''))))
tor_path = os.path.join(
base_path, 'Resources', 'Tor', 'tor')
tor_geo_ip_file_path = os.path.join(
@ -207,20 +208,20 @@ class TestGetTorPaths:
base_path, 'Resources', 'Tor', 'geoip6')
obfs4proxy_file_path = os.path.join(
base_path, 'Resources', 'Tor', 'obfs4proxy')
assert (common.get_tor_paths() ==
assert (common_obj.get_tor_paths() ==
(tor_path, tor_geo_ip_file_path, tor_geo_ipv6_file_path, obfs4proxy_file_path))
# @pytest.mark.skipif(sys.platform != 'Linux', reason='requires Linux') ?
def test_get_tor_paths_linux(self, platform_linux):
assert (common.get_tor_paths() ==
def test_get_tor_paths_linux(self, platform_linux, common_obj):
assert (common_obj.get_tor_paths() ==
('/usr/bin/tor', '/usr/share/tor/geoip', '/usr/share/tor/geoip6', '/usr/bin/obfs4proxy'))
# @pytest.mark.skipif(sys.platform != 'Windows', reason='requires Windows') ?
def test_get_tor_paths_windows(self, platform_windows, sys_frozen):
def test_get_tor_paths_windows(self, platform_windows, common_obj, sys_frozen):
base_path = os.path.join(
os.path.dirname(
os.path.dirname(
common.get_resource_path(''))), 'tor')
common_obj.get_resource_path(''))), 'tor')
tor_path = os.path.join(
os.path.join(base_path, 'Tor'), 'tor.exe')
obfs4proxy_file_path = os.path.join(
@ -231,7 +232,7 @@ class TestGetTorPaths:
tor_geo_ipv6_file_path = os.path.join(
os.path.join(
os.path.join(base_path, 'Data'), 'Tor'), 'geoip6')
assert (common.get_tor_paths() ==
assert (common_obj.get_tor_paths() ==
(tor_path, tor_geo_ip_file_path, tor_geo_ipv6_file_path, obfs4proxy_file_path))
@ -247,8 +248,8 @@ class TestHumanReadableFilesize:
(1024 ** 7, '1.0 ZiB'),
(1024 ** 8, '1.0 YiB')
))
def test_human_readable_filesize(self, test_input, expected):
assert common.human_readable_filesize(test_input) == expected
def test_human_readable_filesize(self, common_obj, test_input, expected):
assert common_obj.human_readable_filesize(test_input) == expected
class TestLog:
@ -263,14 +264,16 @@ class TestLog:
def test_log_msg_regex(self, test_input):
assert bool(LOG_MSG_REGEX.match(test_input))
def test_output(self, set_debug_true, time_strftime):
def test_output(self, common_obj, time_strftime):
def dummy_func():
pass
common_obj.debug = True
# From: https://stackoverflow.com/questions/1218933
with io.StringIO() as buf, contextlib.redirect_stdout(buf):
common.log('TestModule', dummy_func)
common.log('TestModule', dummy_func, 'TEST_MSG')
common_obj.log('TestModule', dummy_func)
common_obj.log('TestModule', dummy_func, 'TEST_MSG')
output = buf.getvalue()
line_one, line_two, _ = output.split('\n')