Force only 1 llama-cpp-python version at a time for now

This commit is contained in:
oobabooga 2024-07-04 19:43:34 -07:00
parent f243b4ca9c
commit a47de06088

View File

@ -7,30 +7,47 @@ from modules import shared
from modules.cache_utils import process_llamacpp_cache
imported_module = None
def llama_cpp_lib():
global imported_module
return_lib = None
if shared.args.cpu:
if imported_module and imported_module != 'llama_cpp':
raise Exception(f"Cannot import 'llama_cpp' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
try:
return_lib = importlib.import_module('llama_cpp')
imported_module = 'llama_cpp'
except:
pass
if shared.args.tensorcores and return_lib is None:
if imported_module and imported_module != 'llama_cpp_cuda_tensorcores':
raise Exception(f"Cannot import 'llama_cpp_cuda_tensorcores' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
try:
return_lib = importlib.import_module('llama_cpp_cuda_tensorcores')
imported_module = 'llama_cpp_cuda_tensorcores'
except:
pass
if return_lib is None:
if imported_module and imported_module != 'llama_cpp_cuda':
raise Exception(f"Cannot import 'llama_cpp_cuda' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
try:
return_lib = importlib.import_module('llama_cpp_cuda')
imported_module = 'llama_cpp_cuda'
except:
pass
if return_lib is None and not shared.args.cpu:
if imported_module and imported_module != 'llama_cpp':
raise Exception(f"Cannot import 'llama_cpp' because '{imported_module}' is already imported. See issue #1575 in llama-cpp-python. Please restart the server before attempting to use a different version of llama-cpp-python.")
try:
return_lib = importlib.import_module('llama_cpp')
imported_module = 'llama_cpp'
except:
pass