2023-09-21 18:35:53 -04:00
import argparse
import glob
2023-10-06 23:23:49 -04:00
import hashlib
2023-09-21 18:35:53 -04:00
import os
2023-09-24 08:58:29 -04:00
import platform
2023-09-26 09:56:57 -04:00
import re
2023-12-05 00:16:16 -05:00
import signal
2023-09-21 23:12:16 -04:00
import site
2023-09-23 09:48:09 -04:00
import subprocess
2023-09-21 18:35:53 -04:00
import sys
# Remove the '# ' from the following lines as needed for your AMD GPU on Linux
# os.environ["ROCM_PATH"] = '/opt/rocm'
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'

# PyTorch release pinned for fresh installs and for the update check
TORCH_VERSION, TORCHVISION_VERSION, TORCHAUDIO_VERSION = "2.2.1", "0.17.1", "2.2.1"

# Environment: installer paths are anchored at the directory the script is launched from
script_dir = os.getcwd()
conda_env_path = os.path.join(script_dir, "installer_files", "env")

# Command-line flags: CMD_FLAGS.txt holds extra arguments for server.py.
# Blank lines and lines starting with '#' are ignored, and trailing backslashes
# (line-continuation markers) are removed before the pieces are joined with spaces.
cmd_flags_path = os.path.join(script_dir, "CMD_FLAGS.txt")
if not os.path.exists(cmd_flags_path):
    CMD_FLAGS = ''
else:
    with open(cmd_flags_path, 'r') as f:
        CMD_FLAGS = ' '.join(
            cleaned
            for cleaned in (line.strip().rstrip('\\').strip() for line in f if not line.strip().startswith('#'))
            if cleaned
        )

# Arguments given to this script are forwarded to server.py (except the installer-only
# --update-wizard switch), followed by the flags read from CMD_FLAGS.txt.
flags = ' '.join(arg for arg in sys.argv[1:] if arg != '--update-wizard') + ' ' + CMD_FLAGS
def signal_handler(sig, frame):
    """Handle SIGINT (Ctrl+C) by terminating with exit status 0.

    Raising SystemExit instead of letting KeyboardInterrupt propagate means the
    user does not see a traceback when interrupting the installer.
    """
    raise SystemExit(0)


signal.signal(signal.SIGINT, signal_handler)
def _platform_startswith(prefix):
    """Return True when sys.platform begins with *prefix*."""
    return sys.platform.startswith(prefix)


def is_linux():
    """Return True when running on Linux."""
    return _platform_startswith("linux")


def is_windows():
    """Return True when running on Windows."""
    return _platform_startswith("win")


def is_macos():
    """Return True when running on macOS."""
    return _platform_startswith("darwin")


def is_x86_64():
    """Return True when platform.machine() reports exactly "x86_64"."""
    return "x86_64" == platform.machine()
def cpu_has_avx2 ( ) :
2023-09-24 11:10:45 -04:00
try :
import cpuinfo
2023-09-24 08:58:29 -04:00
2023-09-24 11:10:45 -04:00
info = cpuinfo . get_cpu_info ( )
if ' avx2 ' in info [ ' flags ' ] :
return True
else :
return False
except :
2023-09-24 08:58:29 -04:00
return True
2023-10-26 22:39:51 -04:00
def cpu_has_amx ( ) :
try :
import cpuinfo
info = cpuinfo . get_cpu_info ( )
if ' amx ' in info [ ' flags ' ] :
return True
else :
return False
except :
return True
def torch_version():
    """Return the installed PyTorch version string (e.g. '2.1.0+cu118').

    If a site-packages directory inside the project conda env is on this
    interpreter's site path, the version is read from torch/version.py there
    without importing torch. Otherwise, or if that file has no __version__
    line, it falls back to importing torch.

    Raises:
        OSError: if the conda env's torch/version.py cannot be opened.
    """
    site_packages_path = next(
        (sitedir for sitedir in site.getsitepackages() if "site-packages" in sitedir and conda_env_path in sitedir),
        None,
    )

    if site_packages_path:
        version_file = os.path.join(site_packages_path, 'torch', 'version.py')
        # Context manager so the file handle is closed (it was previously leaked).
        with open(version_file) as f:
            for line in f:
                if line.startswith('__version__'):
                    # e.g. "__version__ = '2.2.1+cu121'"; accept either quote style.
                    return line.split('__version__ = ')[1].strip().strip("'\"")

    from torch import __version__ as torver
    return torver
def update_pytorch():
    """Upgrade PyTorch to the pinned versions, keeping the currently installed build flavour.

    The flavour is inferred from the local-version suffix of the installed torch
    version string and selects the matching pip index URL.
    """
    print_big_message("Checking for PyTorch updates")
    torver = torch_version()

    # Order matters: '+cu118' must be tested before the more general '+cu'.
    index_by_suffix = (
        ('+cu118', "https://download.pytorch.org/whl/cu118"),  # 2.1.0+cu118
        ('+cu', "https://download.pytorch.org/whl/cu121"),
        ('+rocm', "https://download.pytorch.org/whl/rocm5.6"),  # 2.0.1+rocm5.4.2
        ('+cpu', "https://download.pytorch.org/whl/cpu"),  # 2.0.1+cpu
    )

    install_pytorch = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION} "
    for suffix, index_url in index_by_suffix:
        if suffix in torver:
            install_pytorch += f"--index-url {index_url}"
            break
    else:
        # Intel build, e.g. 2.0.1a0+cxx11.abi: replace the whole command.
        if '+cxx11' in torver:
            if is_linux():
                install_pytorch = "python -m pip install --upgrade torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
            else:
                install_pytorch = "python -m pip install --upgrade torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"

    run_cmd(f"{install_pytorch}", assert_success=True, environment=True)
def is_installed():
    """Return whether the web UI environment appears to be installed.

    If a site-packages directory inside the project conda env is on this
    interpreter's site path, the answer is whether torch is present there.
    Otherwise it only checks that the conda env directory exists.
    """
    candidates = [
        sitedir
        for sitedir in site.getsitepackages()
        if "site-packages" in sitedir and conda_env_path in sitedir
    ]
    if not candidates:
        return os.path.isdir(conda_env_path)

    torch_init = os.path.join(candidates[0], 'torch', '__init__.py')
    return os.path.isfile(torch_init)
def check_env():
    """Exit with status 1 unless conda works and the active conda env is not 'base'.

    Raises:
        SystemExit: (code 1) if `conda` fails inside the project env, or if
            CONDA_DEFAULT_ENV is 'base'.
    """
    # If we have access to conda, we are probably in an environment
    conda_exist = run_cmd("conda", environment=True, capture_output=True).returncode == 0
    if not conda_exist:
        print("Conda is not installed. Exiting...")
        sys.exit(1)

    # Ensure this is a new environment and not the base environment.
    # .get(): CONDA_DEFAULT_ENV is absent when no env is activated in the calling
    # shell; indexing os.environ crashed with KeyError in that case.
    if os.environ.get("CONDA_DEFAULT_ENV") == "base":
        print("Create an environment for this project and activate it. Exiting...")
        sys.exit(1)
def clear_cache():
    """Purge conda's caches (`conda clean -a`) and pip's cache inside the project env."""
    for purge_cmd in ("conda clean -a -y", "python -m pip cache purge"):
        run_cmd(purge_cmd, environment=True)
def print_big_message(message):
    """Print *message* inside a banner of asterisks, each line prefixed with '* '.

    Leading/trailing whitespace of the whole message is removed first.
    """
    border = "*******************************************************************"
    print("\n\n" + border)
    for text in message.strip().split('\n'):
        print("*", text)
    print(border + "\n\n")
def calculate_file_hash(file_path):
    """Return the hex SHA-256 digest of *file_path*, resolved against script_dir.

    Returns '' when the path is not an existing regular file.
    """
    full_path = os.path.join(script_dir, file_path)
    if not os.path.isfile(full_path):
        return ''

    digest = hashlib.sha256()
    with open(full_path, 'rb') as handle:
        for chunk in iter(lambda: handle.read(65536), b''):
            digest.update(chunk)
    return digest.hexdigest()
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None):
    """Run *cmd* through the shell and return the CompletedProcess.

    Args:
        cmd: Shell command string.
        assert_success: If True, print the failing command and exit with status 1
            on a non-zero return code.
        environment: If True, activate the project conda env (conda.bat on
            Windows, conda.sh elsewhere) before running *cmd*.
        capture_output: Forwarded to subprocess.run.
        env: Environment mapping forwarded to subprocess.run.
    """
    full_cmd = cmd
    if environment:
        if is_windows():
            activator = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
            full_cmd = f'"{activator}" activate "{conda_env_path}" >nul && {cmd}'
        else:
            activator = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
            full_cmd = f'. "{activator}" && conda activate "{conda_env_path}" && {cmd}'

    completed = subprocess.run(full_cmd, shell=True, capture_output=capture_output, env=env)

    if assert_success and completed.returncode != 0:
        print(f"Command '{full_cmd}' failed with exit status code '{str(completed.returncode)}'.\n\nExiting now.\nTry running the start/update script again.")
        sys.exit(1)

    return completed
def generate_alphabetic_sequence(index):
    """Convert a zero-based *index* into a spreadsheet-style column label.

    0 -> 'A', 25 -> 'Z', 26 -> 'AA', 27 -> 'AB', ...; negative indices give ''.
    """
    letters = []
    remaining = index
    while remaining >= 0:
        remaining, offset = divmod(remaining, 26)
        letters.append(chr(ord('A') + offset))
        remaining -= 1
    return ''.join(reversed(letters))
2024-03-04 13:52:24 -05:00
def get_user_choice ( question , options_dict ) :
print ( )
print ( question )
print ( )
for key , value in options_dict . items ( ) :
print ( f " { key } ) { value } " )
print ( )
choice = input ( " Input> " ) . upper ( )
while choice not in options_dict . keys ( ) :
print ( " Invalid choice. Please try again. " )
choice = input ( " Input> " ) . upper ( )
return choice
def install_webui():
    """Perform the first-time installation: pick a GPU backend, install PyTorch, then requirements.

    The GPU vendor is taken from the GPU_CHOICE environment variable when it holds
    a valid menu key, otherwise from an interactive menu. For NVIDIA on Windows or
    Linux, CUDA 11.8 vs 12.1 is decided by USE_CUDA118 or a Y/N prompt.

    Raises:
        SystemExit: (code 1) for AMD on Windows, or if a required install step fails.
    """
    gpu_menu = {
        'A': 'NVIDIA',
        'B': 'AMD (Linux/MacOS only. Requires ROCm SDK 5.6 on Linux)',
        'C': 'Apple M Series',
        'D': 'Intel Arc (IPEX)',
        'N': 'None (I want to run models in CPU mode)'
    }
    gpu_choice_to_name = {
        "A": "NVIDIA",
        "B": "AMD",
        "C": "APPLE",
        "D": "INTEL",
        "N": "NONE"
    }

    # Ask the user for the GPU vendor
    choice = None
    if "GPU_CHOICE" in os.environ:
        env_choice = os.environ["GPU_CHOICE"].strip().upper()
        if env_choice in gpu_choice_to_name:
            choice = env_choice
            print_big_message(f"Selected GPU choice \"{choice}\" based on the GPU_CHOICE environment variable.")
        else:
            # An unrecognized value used to crash with a KeyError below; ask instead.
            print(f"Ignoring invalid GPU_CHOICE value \"{os.environ['GPU_CHOICE']}\".")

    if choice is None:
        choice = get_user_choice("What is your GPU?", gpu_menu)

    selected_gpu = gpu_choice_to_name[choice]
    use_cuda118 = "N"

    # Write a flag to CMD_FLAGS.txt for CPU mode
    if selected_gpu == "NONE":
        # 'a+' creates CMD_FLAGS.txt if it does not exist ('r+' raised FileNotFoundError).
        # seek(0) rewinds for the read; writes in append mode always go to the end.
        with open(cmd_flags_path, 'a+') as cmd_flags_file:
            cmd_flags_file.seek(0)
            if "--cpu" not in cmd_flags_file.read():
                print_big_message("Adding the --cpu flag to CMD_FLAGS.txt.")
                cmd_flags_file.write("\n--cpu\n")

    # Check if the user wants CUDA 11.8
    elif any((is_windows(), is_linux())) and selected_gpu == "NVIDIA":
        if "USE_CUDA118" in os.environ:
            use_cuda118 = "Y" if os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "true", "1", "t", "on") else "N"
        else:
            print("\nDo you want to use CUDA 11.8 instead of 12.1?\nOnly choose this option if your GPU is very old (Kepler or older).\n\nFor RTX and GTX series GPUs, say \"N\".\nIf unsure, say \"N\".\n")
            use_cuda118 = input("Input (Y/N)> ").upper().strip('"\'').strip()
            # Tuple membership: the old substring test `not in 'YN'` also accepted "YN".
            # An empty answer is still accepted and means "N", as before.
            while use_cuda118 not in ('Y', 'N', ''):
                print("Invalid choice. Please try again.")
                use_cuda118 = input("Input> ").upper().strip('"\'').strip()
            use_cuda118 = use_cuda118 or 'N'

        if use_cuda118 == 'Y':
            print("CUDA: 11.8")
        else:
            print("CUDA: 12.1")

    # No PyTorch for AMD on Windows (?)
    elif is_windows() and selected_gpu == "AMD":
        print("PyTorch setup on Windows is not implemented yet. Exiting...")
        sys.exit(1)

    # Find the Pytorch installation command
    install_pytorch = f"python -m pip install torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION} "
    if selected_gpu == "NVIDIA":
        if use_cuda118 == 'Y':
            install_pytorch += "--index-url https://download.pytorch.org/whl/cu118"
        else:
            install_pytorch += "--index-url https://download.pytorch.org/whl/cu121"
    elif selected_gpu == "AMD":
        install_pytorch += "--index-url https://download.pytorch.org/whl/rocm5.6"
    elif selected_gpu in ["APPLE", "NONE"]:
        install_pytorch += "--index-url https://download.pytorch.org/whl/cpu"
    elif selected_gpu == "INTEL":
        if is_linux():
            install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
        else:
            install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"

    # Install Git and then Pytorch
    print_big_message("Installing PyTorch.")
    run_cmd(f"conda install -y -k ninja git && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True)

    if selected_gpu == "INTEL":
        # Install oneAPI dependencies via conda. environment=True activates the project
        # env first, like every other install step here (it was missing before).
        print_big_message("Installing Intel oneAPI runtime libraries.")
        run_cmd("conda install -y -c intel dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0", environment=True)
        # Install libuv required by Intel-patched torch
        run_cmd("conda install -y libuv", environment=True)

    # Install the webui requirements
    update_requirements(initial_installation=True)
def get_extensions_names():
    """Return the names of extensions that ship a requirements.txt, sorted alphabetically.

    Looks in the 'extensions' folder relative to the current working directory and
    returns [] when that folder does not exist. Sorting makes the order stable
    across runs, because os.listdir() order is arbitrary. The update wizard assigns
    menu letters by position in this list.
    """
    if not os.path.isdir('extensions'):
        return []

    return sorted(
        foldername
        for foldername in os.listdir('extensions')
        if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))
    )
def install_extensions_requirements():
    """pip-install (with --upgrade) every extension's requirements.txt, continuing past failures."""
    print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon't worry if you see error messages, as they will not affect the main program.")
    extensions = get_extensions_names()
    total = len(extensions)
    for position, extension in enumerate(extensions, start=1):
        print(f"\n\n--- [{position}/{total}]: {extension}\n\n")
        requirements_path = os.path.join("extensions", extension, "requirements.txt")
        # assert_success=False: one failing extension must not abort the whole run
        run_cmd(f"python -m pip install -r {requirements_path} --upgrade", assert_success=False, environment=True)
def update_requirements(initial_installation=False, pull=True):
    """Update the repository checkout and the Python packages of the web UI.

    Args:
        initial_installation: True when called right after the first install;
            skips the PyTorch update step.
        pull: When True, run `git pull --autostash` first and exit with status 1
            if any launcher/installer script changed, so the user re-runs it.

    The requirements file is chosen from the installed torch build (ROCm,
    CPU/Intel, macOS, or the default one, with no-AVX2 variants where
    applicable). It is rewritten for CUDA 11.8 when needed and installed with pip.

    Raises:
        SystemExit: (code 1) when installer files changed during the pull or a
            required command fails.
    """
    # Create .git directory if missing
    if not os.path.exists(os.path.join(script_dir, ".git")):
        git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git symbolic-ref refs/remotes/origin/HEAD refs/remotes/origin/main && git reset --hard origin/main && git branch --set-upstream-to=origin/main'
        run_cmd(git_creation_cmd, environment=True, assert_success=True)

    if pull:
        print_big_message("Updating the local copy of the repository with \"git pull\"")

        files_to_check = [
            'start_linux.sh', 'start_macos.sh', 'start_windows.bat', 'start_wsl.bat',
            'update_linux.sh', 'update_macos.sh', 'update_windows.bat', 'update_wsl.bat',
            'one_click.py'
        ]

        before_pull_hashes = {file_name: calculate_file_hash(file_name) for file_name in files_to_check}
        run_cmd("git pull --autostash", assert_success=True, environment=True)
        after_pull_hashes = {file_name: calculate_file_hash(file_name) for file_name in files_to_check}

        # Check for differences in installation file hashes
        for file_name in files_to_check:
            if before_pull_hashes[file_name] != after_pull_hashes[file_name]:
                print_big_message(f"File '{file_name}' was updated during 'git pull'. Please run the script again.")
                # sys.exit instead of the site-provided exit(), which is missing under `python -S`
                sys.exit(1)

    if os.environ.get("INSTALL_EXTENSIONS", "").lower() in ("yes", "y", "true", "1", "t", "on"):
        install_extensions_requirements()

    # Update PyTorch
    if not initial_installation:
        update_pytorch()

    # Detect the PyTorch version
    torver = torch_version()
    is_cuda = '+cu' in torver
    is_cuda118 = '+cu118' in torver  # 2.1.0+cu118
    is_rocm = '+rocm' in torver  # 2.0.1+rocm5.4.2
    is_intel = '+cxx11' in torver  # 2.0.1a0+cxx11.abi
    is_cpu = '+cpu' in torver  # 2.0.1+cpu

    if is_rocm:
        base_requirements = "requirements_amd" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
    elif is_cpu or is_intel:
        base_requirements = "requirements_cpu_only" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
    elif is_macos():
        base_requirements = "requirements_apple_" + ("intel" if is_x86_64() else "silicon") + ".txt"
    else:
        base_requirements = "requirements" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"

    requirements_file = base_requirements

    print_big_message(f"Installing webui requirements from file: {requirements_file}")
    print(f"TORCH: {torver}\n")

    # Prepare the requirements file (context manager: the handle used to be leaked)
    with open(requirements_file) as req_file:
        textgen_requirements = req_file.read().splitlines()

    if is_cuda118:
        textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]
    if is_windows() and is_cuda118:  # No flash-attention on Windows for CUDA 11
        textgen_requirements = [req for req in textgen_requirements if 'jllllll/flash-attention' not in req]

    with open('temp_requirements.txt', 'w') as file:
        file.write('\n'.join(textgen_requirements))

    # Workaround for git+ packages not updating properly.
    git_requirements = [req for req in textgen_requirements if req.startswith("git+")]
    for req in git_requirements:
        url = req.replace("git+", "")
        package_name = url.split("/")[-1].split("@")[0]
        # Remove a literal ".git" suffix. The former `.rstrip(".git")` stripped any
        # trailing '.', 'g', 'i' or 't' characters (e.g. "peft" became "pef").
        if package_name.endswith(".git"):
            package_name = package_name[:-len(".git")]
        run_cmd(f"python -m pip uninstall -y {package_name}", environment=True)
        print(f"Uninstalled {package_name}")

    # Install/update the project requirements. The temporary file is removed even
    # when the install fails (run_cmd then raises SystemExit).
    try:
        run_cmd("python -m pip install -r temp_requirements.txt --upgrade", assert_success=True, environment=True)
    finally:
        os.remove('temp_requirements.txt')

    # Check for '+cu' or '+rocm' in version string to determine if torch uses CUDA or ROCm. Check for pytorch-cuda as well for backwards compatibility
    if not any((is_cuda, is_rocm)) and run_cmd("conda list -f pytorch-cuda | grep pytorch-cuda", environment=True, capture_output=True).returncode == 1:
        clear_cache()
        return

    if not os.path.exists("repositories/"):
        os.mkdir("repositories")

    clear_cache()
def launch_webui():
    """Start server.py inside the project conda env, forwarding the collected flags."""
    command = "python server.py " + flags
    run_cmd(command, environment=True)
if __name__ == "__main__":
    # Verifies we are in a conda environment
    check_env()

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--update-wizard', action='store_true', help='Launch a menu with update options.')
    # parse_known_args: all other arguments are meant for server.py and are forwarded via `flags`
    args, _ = parser.parse_known_args()

    if args.update_wizard:
        while True:
            choice = get_user_choice(
                "What would you like to do?",
                {
                    'A': 'Update the web UI',
                    'B': 'Install/update extensions requirements',
                    'C': 'Revert local changes to repository files with \"git reset --hard\"',
                    'N': 'Nothing (exit)'
                },
            )

            if choice == 'A':
                update_requirements()
            elif choice == 'B':
                # 'A' is reserved for "All extensions", so individual extensions start at 'B'
                choices = {'A': 'All extensions'}
                for i, name in enumerate(get_extensions_names()):
                    key = generate_alphabetic_sequence(i + 1)
                    choices[key] = name

                choice = get_user_choice("What extension?", choices)

                if choice == 'A':
                    install_extensions_requirements()
                else:
                    extension_req_path = os.path.join("extensions", choices[choice], "requirements.txt")
                    run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True)

                # Re-install the web UI requirements (without git pull) after the extension step
                update_requirements(pull=False)
            elif choice == 'C':
                run_cmd("git reset --hard", assert_success=True, environment=True)
            elif choice == 'N':
                sys.exit()
    else:
        if not is_installed():
            install_webui()
            os.chdir(script_dir)

        if os.environ.get("LAUNCH_AFTER_INSTALL", "").lower() in ("no", "n", "false", "0", "f", "off"):
            print_big_message("Will now exit due to LAUNCH_AFTER_INSTALL.")
            sys.exit()

        # Check if a model has been downloaded yet
        model_dir = 'models'
        if '--model-dir' in flags:
            # Splits on ' ' or '=' while maintaining spaces within quotes
            flags_list = re.split(' +(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)|=', flags)
            # Guard against a dangling "--model-dir" with no value, or a token that only
            # contains the text; both used to raise IndexError. Fall back to 'models'.
            if '--model-dir' in flags_list:
                value_index = flags_list.index('--model-dir') + 1
                if value_index < len(flags_list) and flags_list[value_index].strip('"\''):
                    model_dir = flags_list[value_index].strip('"\'')

        if len([item for item in glob.glob(f'{model_dir}/*') if not item.endswith(('.txt', '.yaml'))]) == 0:
            print_big_message("You haven't downloaded any model yet.\nOnce the web UI launches, head over to the \"Model\" tab and download one.")

        # Workaround for llama-cpp-python loading paths in CUDA env vars even if they do not exist
        conda_path_bin = os.path.join(conda_env_path, "bin")
        if not os.path.exists(conda_path_bin):
            os.mkdir(conda_path_bin)

        # Launch the webui
        launch_webui()