diff --git a/Colab-TextGen-GPU.ipynb b/Colab-TextGen-GPU.ipynb index bb287702..50a6fbd1 100644 --- a/Colab-TextGen-GPU.ipynb +++ b/Colab-TextGen-GPU.ipynb @@ -70,9 +70,9 @@ "\n", " textgen_requirements = open('requirements.txt').read().splitlines()\n", " if is_cuda117:\n", - " textgen_requirements = [req.replace('+cu121', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]\n", + " textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]\n", " elif is_cuda118:\n", - " textgen_requirements = [req.replace('+cu121', '+cu118') for req in textgen_requirements]\n", + " textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]\n", " with open('temp_requirements.txt', 'w') as file:\n", " file.write('\\n'.join(textgen_requirements))\n", "\n", diff --git a/one_click.py b/one_click.py index 52c1b752..8fc5cfce 100644 --- a/one_click.py +++ b/one_click.py @@ -174,7 +174,7 @@ def install_webui(): use_cuda118 = "N" if any((is_windows(), is_linux())) and choice == "A": if "USE_CUDA118" in os.environ: - use_cuda118 = os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "trye", "1", "t", "on") + use_cuda118 = "Y" if os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "true", "1", "t", "on") else "N" else: # Ask for CUDA version if using NVIDIA print("\nWould you like to use CUDA 11.8 instead of 12.1? This is only necessary for older GPUs like Kepler.\nIf unsure, say \"N\".\n") @@ -182,10 +182,10 @@ def install_webui(): while use_cuda118 not in 'YN': print("Invalid choice. Please try again.") use_cuda118 = input("Input> ").upper().strip('"\'').strip() - if use_cuda118 == 'Y': - print(f"CUDA: 11.8") - else: - print(f"CUDA: 12.1") + if use_cuda118 == 'Y': + print("CUDA: 11.8") + else: + print("CUDA: 12.1") install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}" elif not is_macos() and choice == "B": @@ -204,7 +204,7 @@ def install_webui(): # Install CUDA libraries (this wasn't necessary for Pytorch before...) if choice == "A": - run_cmd(f"conda install -y -c \"nvidia/label/{'cuda-12.1.0' if use_cuda118 == 'N' else 'cuda-11.8.0'}\" cuda-runtime", assert_success=True, environment=True) + run_cmd(f"conda install -y -c \"nvidia/label/{'cuda-12.1.1' if use_cuda118 == 'N' else 'cuda-11.8.0'}\" cuda-runtime", assert_success=True, environment=True) # Install the webui requirements update_requirements(initial_installation=True) @@ -286,9 +286,9 @@ def update_requirements(initial_installation=False): print_big_message(f"Installing webui requirements from file: {requirements_file}") textgen_requirements = open(requirements_file).read().splitlines() if is_cuda117: - textgen_requirements = [req.replace('+cu121', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements] + textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements] elif is_cuda118: - textgen_requirements = [req.replace('+cu121', '+cu118') for req in textgen_requirements] + textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements] with open('temp_requirements.txt', 'w') as file: file.write('\n'.join(textgen_requirements))