USE_CUDA118 from ENV remains null one_click.py + cuda-toolkit (#4352)

This commit is contained in:
mongolu 2023-10-22 18:37:24 +03:00 committed by GitHub
parent cd45635f53
commit c18504f369
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 10 deletions

View File

@ -70,9 +70,9 @@
"\n", "\n",
" textgen_requirements = open('requirements.txt').read().splitlines()\n", " textgen_requirements = open('requirements.txt').read().splitlines()\n",
" if is_cuda117:\n", " if is_cuda117:\n",
" textgen_requirements = [req.replace('+cu121', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]\n", " textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]\n",
" elif is_cuda118:\n", " elif is_cuda118:\n",
" textgen_requirements = [req.replace('+cu121', '+cu118') for req in textgen_requirements]\n", " textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]\n",
" with open('temp_requirements.txt', 'w') as file:\n", " with open('temp_requirements.txt', 'w') as file:\n",
" file.write('\\n'.join(textgen_requirements))\n", " file.write('\\n'.join(textgen_requirements))\n",
"\n", "\n",

View File

@ -174,7 +174,7 @@ def install_webui():
use_cuda118 = "N" use_cuda118 = "N"
if any((is_windows(), is_linux())) and choice == "A": if any((is_windows(), is_linux())) and choice == "A":
if "USE_CUDA118" in os.environ: if "USE_CUDA118" in os.environ:
use_cuda118 = os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "trye", "1", "t", "on") use_cuda118 = "Y" if os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "true", "1", "t", "on") else "N"
else: else:
# Ask for CUDA version if using NVIDIA # Ask for CUDA version if using NVIDIA
print("\nWould you like to use CUDA 11.8 instead of 12.1? This is only necessary for older GPUs like Kepler.\nIf unsure, say \"N\".\n") print("\nWould you like to use CUDA 11.8 instead of 12.1? This is only necessary for older GPUs like Kepler.\nIf unsure, say \"N\".\n")
@ -183,9 +183,9 @@ def install_webui():
print("Invalid choice. Please try again.") print("Invalid choice. Please try again.")
use_cuda118 = input("Input> ").upper().strip('"\'').strip() use_cuda118 = input("Input> ").upper().strip('"\'').strip()
if use_cuda118 == 'Y': if use_cuda118 == 'Y':
print(f"CUDA: 11.8") print("CUDA: 11.8")
else: else:
print(f"CUDA: 12.1") print("CUDA: 12.1")
install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}" install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}"
elif not is_macos() and choice == "B": elif not is_macos() and choice == "B":
@ -204,7 +204,7 @@ def install_webui():
# Install CUDA libraries (this wasn't necessary for Pytorch before...) # Install CUDA libraries (this wasn't necessary for Pytorch before...)
if choice == "A": if choice == "A":
run_cmd(f"conda install -y -c \"nvidia/label/{'cuda-12.1.0' if use_cuda118 == 'N' else 'cuda-11.8.0'}\" cuda-runtime", assert_success=True, environment=True) run_cmd(f"conda install -y -c \"nvidia/label/{'cuda-12.1.1' if use_cuda118 == 'N' else 'cuda-11.8.0'}\" cuda-runtime", assert_success=True, environment=True)
# Install the webui requirements # Install the webui requirements
update_requirements(initial_installation=True) update_requirements(initial_installation=True)
@ -286,9 +286,9 @@ def update_requirements(initial_installation=False):
print_big_message(f"Installing webui requirements from file: {requirements_file}") print_big_message(f"Installing webui requirements from file: {requirements_file}")
textgen_requirements = open(requirements_file).read().splitlines() textgen_requirements = open(requirements_file).read().splitlines()
if is_cuda117: if is_cuda117:
textgen_requirements = [req.replace('+cu121', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements] textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]
elif is_cuda118: elif is_cuda118:
textgen_requirements = [req.replace('+cu121', '+cu118') for req in textgen_requirements] textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]
with open('temp_requirements.txt', 'w') as file: with open('temp_requirements.txt', 'w') as file:
file.write('\n'.join(textgen_requirements)) file.write('\n'.join(textgen_requirements))