diff --git a/one_click.py b/one_click.py index a9ed8d15..d2516c5e 100644 --- a/one_click.py +++ b/one_click.py @@ -53,7 +53,16 @@ def cpu_has_avx2(): def torch_version(): - from torch import __version__ as torver + for sitedir in site.getsitepackages(): + if "site-packages" in sitedir and conda_env_path in sitedir: + site_packages_path = sitedir + break + + if site_packages_path: + torch_version_file = open(os.path.join(site_packages_path, 'torch', 'version.py')).read().splitlines() + torver = [line for line in torch_version_file if '__version__' in line][0].split('__version__ = ')[1].strip("'") + else: + from torch import __version__ as torver return torver