2024-07-04 22:15:37 -04:00
import importlib
2024-07-12 23:04:19 -04:00
import platform
2024-07-22 21:05:11 -04:00
from typing import Sequence
from tqdm import tqdm
2024-02-08 00:40:58 -05:00
2024-03-08 22:25:33 -05:00
from modules import shared
from modules . cache_utils import process_llamacpp_cache
2024-07-22 21:05:11 -04:00
2024-07-04 22:43:34 -04:00
# Name of the llama-cpp-python module that llama_cpp_lib() has successfully
# imported in this process, or None until the first successful import.
# llama_cpp_lib() uses it to refuse switching to a different variant later.
imported_module = None
2024-07-04 22:15:37 -04:00
def llama_cpp_lib():
    """Import and return the llama-cpp-python variant selected by the settings.

    Candidates are tried in priority order:
    - An entry whose flag is None is always tried.
    - An entry with a flag name is tried only when that attribute of
      ``shared.args`` is truthy.

    The first candidate that imports successfully is monkey-patched, recorded
    in the module-level ``imported_module`` and returned.

    Only one variant can be used per process. Suppose a variant was imported
    by an earlier call, and the current settings select a different variant
    that is actually installed. Then an exception is raised instead of
    importing it. Variants that are not installed are skipped, exactly as the
    ImportError fallback skips them on the first call.

    Returns:
        The imported module, or None if no candidate could be imported.

    Raises:
        Exception: if a different, installed variant than the one already
            imported is selected.
    """
    global imported_module

    # Local import: only needed here, to check whether a variant is installed
    # without actually importing it.
    from importlib.util import find_spec

    # Determine the platform
    is_macos = platform.system() == 'Darwin'

    # Define the library names based on the platform:
    # (flag name in shared.args or None, module name), in priority order.
    if is_macos:
        lib_names = [
            (None, 'llama_cpp')
        ]
    else:
        lib_names = [
            ('cpu', 'llama_cpp'),
            ('tensorcores', 'llama_cpp_cuda_tensorcores'),
            (None, 'llama_cpp_cuda'),
            (None, 'llama_cpp')
        ]

    for arg, lib_name in lib_names:
        should_import = arg is None or getattr(shared.args, arg, False)
        if not should_import:
            continue

        if imported_module and imported_module != lib_name:
            # A variant that is not installed cannot be the one to switch to.
            # Skip it, mirroring the ImportError fallback below. Otherwise a
            # process that fell back to e.g. `llama_cpp` would hit the
            # conflict error on every later call.
            if find_spec(lib_name) is None:
                continue

            # Conflict detected, raise an exception
            raise Exception(f"Cannot import `{lib_name}` because `{imported_module}` is already imported. Switching to a different version of llama-cpp-python currently requires a server restart.")

        try:
            return_lib = importlib.import_module(lib_name)
        except ImportError:
            # This variant is unavailable; fall through to the next one.
            continue

        imported_module = lib_name
        monkey_patch_llama_cpp_python(return_lib)
        return return_lib

    return None
2024-04-30 08:11:31 -04:00
2024-02-08 00:40:58 -05:00
2024-07-22 21:05:11 -04:00
def eval_with_progress(self, tokens: Sequence[int]):
    """
    Replacement for llama-cpp-python's ``Llama.eval`` that displays a tqdm
    progress bar while a long prompt is being processed.

    Adapted from
    https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py

    Tokens are decoded in chunks of at most ``self.n_batch``. After each chunk:
    - the chunk's tokens are stored in ``self.input_ids``;
    - its logits are stored in ``self.scores``: one row per token when
      ``context_params.logits_all`` is set, otherwise only the row of the
      chunk's last token;
    - ``self.n_tokens`` advances by the chunk length.

    The bar is shown only when more than one chunk is needed.
    """
    assert self._ctx.ctx is not None
    assert self._batch.batch is not None
    # Same call as upstream. Presumably it drops cached KV entries from
    # position n_tokens onward (-1 as "all sequences" / "to the end") —
    # see llama.cpp's kv_cache_seq_rm.
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    total = len(tokens)
    chunk_starts = range(0, total, self.n_batch)
    if total > self.n_batch:
        chunk_starts = tqdm(chunk_starts, desc="Prompt evaluation", leave=False)

    for start in chunk_starts:
        chunk = tokens[start:min(total, start + self.n_batch)]
        offset = self.n_tokens
        count = len(chunk)

        self._batch.set_batch(
            batch=chunk, n_past=offset, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)

        # Record the evaluated tokens.
        self.input_ids[offset:offset + count] = chunk

        # Record the logits: every row of the chunk, or only its last row.
        if self.context_params.logits_all:
            n_rows = count
            destination = self.scores[offset:offset + count, :]
        else:
            n_rows = 1
            destination = self.scores[offset + count - 1, :]
        destination.reshape(-1)[::] = self._ctx.get_logits()[:n_rows * self._n_vocab]

        self.n_tokens += count
2024-07-04 22:15:37 -04:00
def monkey_patch_llama_cpp_python(lib):
    """Patch ``lib.Llama`` in place for use by the web UI.

    - ``Llama.eval`` is replaced by ``eval_with_progress`` (prompt evaluation
      with a progress bar).
    - ``Llama.generate`` is wrapped so that ``process_llamacpp_cache`` trims
      the cache before generation when ``shared.args.streaming_llm`` is set.
      The unwrapped method remains available as ``Llama.original_generate``.

    A second call on an already-patched class is a no-op, guarded by
    ``Llama._is_patched``. Patching twice would rebind ``original_generate``
    to the wrapper itself, and ``generate`` would then recurse forever.

    Args:
        lib: an imported llama-cpp-python module exposing a ``Llama`` class.
    """
    if getattr(lib.Llama, '_is_patched', False):
        # If the patch is already applied, do nothing
        return

    def my_generate(self, *args, **kwargs):
        if shared.args.streaming_llm:
            # The prompt tokens may arrive positionally or as the `tokens`
            # keyword (the parameter name of llama-cpp-python's
            # Llama.generate).
            new_sequence = args[0] if args else kwargs['tokens']
            past_sequence = self._input_ids

            # Do the cache trimming for StreamingLLM
            process_llamacpp_cache(self, new_sequence, past_sequence)

        # `yield from`, unlike a for/yield loop, forwards send(), throw() and
        # close() to the wrapped generator. Callers therefore get the same
        # generator protocol as with the unpatched method.
        yield from self.original_generate(*args, **kwargs)

    lib.Llama.eval = eval_with_progress
    lib.Llama.original_generate = lib.Llama.generate
    lib.Llama.generate = my_generate

    # Set the flag to indicate that the patch has been applied
    lib.Llama._is_patched = True