import gc
import json
import logging
import os
import re
import time
import zipfile
from pathlib import Path

import numpy as np
import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoModelForSeq2SeqLM, AutoTokenizer,
                          BitsAndBytesConfig, LlamaTokenizer)

import modules.shared as shared
from modules import llama_attn_hijack

transformers.logging.set_verbosity_error()
if shared.args.flexgen:
    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy

local_rank = None
if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import (HfDeepSpeedConfig,
                                        is_deepspeed_zero3_enabled)

    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup
    local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed()
    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
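    # Note: this branch assumes the process was started by the DeepSpeed launcher,
    # which sets the LOCAL_RANK and WORLD_SIZE environment variables read above
    # (for example `deepspeed --num_gpus=1 server.py --deepspeed`, assuming
    # server.py is this project's entry point).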


# Infer the model type from the model name; fall back to the HF config for generic models
def find_model_type(model_name):
    model_name_lower = model_name.lower()
    if 'rwkv-' in model_name_lower:
        return 'rwkv'
    elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0:
        return 'llamacpp'
    elif re.match(r'.*ggml.*\.bin', model_name_lower):
        return 'llamacpp'
    elif 'chatglm' in model_name_lower:
        return 'chatglm'
    elif 'galactica' in model_name_lower:
        return 'galactica'
    elif 'llava' in model_name_lower:
        return 'llava'
    elif 'oasst' in model_name_lower:
        return 'oasst'
    elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
        return 'gpt4chan'
    else:
        config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code)

        # Not a "catch all", but fairly accurate
        if config.to_dict().get("is_encoder_decoder", False):
            return 'HF_seq2seq'
        else:
            return 'HF_generic'


def load_model(model_name):
    logging.info(f"Loading {model_name}...")
    t0 = time.time()

    shared.model_type = find_model_type(model_name)
    trust_remote_code = shared.args.trust_remote_code
    if shared.model_type == 'chatglm':
        LoaderClass = AutoModel
    elif shared.model_type == 'HF_seq2seq':
        LoaderClass = AutoModelForSeq2SeqLM
    else:
        LoaderClass = AutoModelForCausalLM

    # Load the model in simple 16-bit mode by default
    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
        if torch.has_mps:
            device = torch.device('mps')
            model = model.to(device)
        else:
            model = model.cuda()
    # FlexGen
    elif shared.args.flexgen:
        # Initialize environment
        env = ExecutionEnv.create(shared.args.disk_cache_dir)

        # Offloading policy
        policy = Policy(1, 1,
                        shared.args.percent[0], shared.args.percent[1],
                        shared.args.percent[2], shared.args.percent[3],
                        shared.args.percent[4], shared.args.percent[5],
                        overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
                        cpu_cache_compute=False, attn_sparsity=1.0,
                        compress_weight=shared.args.compress_weight,
                        comp_weight_config=CompressionConfig(
                            num_bits=4, group_size=64,
                            group_dim=0, symmetric=False),
                        compress_cache=False,
                        comp_cache_config=CompressionConfig(
                            num_bits=4, group_size=64,
                            group_dim=2, symmetric=False))

        model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
        logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
    # RWKV model (not on HuggingFace)
    elif shared.model_type == 'rwkv':
        from modules.RWKV import RWKVModel, RWKVTokenizer

        model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
        tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))

        return model, tokenizer
    # llamacpp model
    elif shared.model_type == 'llamacpp':
        from modules.llamacpp_model import LlamaCppModel

        path = Path(f'{shared.args.model_dir}/{model_name}')
        if path.is_file():
            model_file = path
        else:
            model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]

        logging.info(f"llama.cpp weights detected: {model_file}\n")
        model, tokenizer = LlamaCppModel.from_pretrained(model_file)
        return model, tokenizer
    # Quantized model
    elif shared.args.wbits > 0:

        # Monkey patch
        if shared.args.monkey_patch:
            logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
            from modules.monkey_patch_gptq_lora import load_model_llama

            model, _ = load_model_llama(model_name)

        # No monkey patch
        else:
            from modules.GPTQ_loader import load_quantized

            model = load_quantized(model_name)
    # Custom
    else:
        params = {
            "low_cpu_mem_usage": True,
            "trust_remote_code": trust_remote_code
        }

        if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
            logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
            shared.args.cpu = True

        if shared.args.cpu:
            params["torch_dtype"] = torch.float32
        else:
            params["device_map"] = 'auto'
            if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
            elif shared.args.load_in_8bit:
                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
            elif shared.args.bf16:
                params["torch_dtype"] = torch.bfloat16
            else:
                params["torch_dtype"] = torch.float16
            if shared.args.gpu_memory:
                memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
                max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
                max_memory = {}
                # Values without a unit are treated as GiB, e.g. --gpu-memory 10 5
                # becomes {0: '10GiB', 1: '5GiB'}
                for i in range(len(memory_map)):
                    max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]

                max_memory['cpu'] = max_cpu_memory
                params['max_memory'] = max_memory
            elif shared.args.auto_devices:
                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
                suggestion = round((total_mem - 1000) / 1000) * 1000
                if total_mem - suggestion < 800:
                    suggestion -= 1000

                suggestion = int(round(suggestion / 1000))
                logging.warning(f"\033[1;32;1mAuto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
                max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
                params['max_memory'] = max_memory

            if shared.args.disk:
                params["offload_folder"] = shared.args.disk_cache_dir
        checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
            config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=trust_remote_code)
            # Build the device map on an empty (meta) model so that the 8-bit weights
            # can be placed across the available devices before they are loaded
            with init_empty_weights():
                model = LoaderClass.from_config(config, trust_remote_code=trust_remote_code)

            model.tie_weights()
            params['device_map'] = infer_auto_device_map(
                model,
                dtype=torch.int8,
                max_memory=params['max_memory'],
                no_split_module_classes=model._no_split_modules
            )

        model = LoaderClass.from_pretrained(checkpoint, **params)
    # Hijack attention with xformers
    if any((shared.args.xformers, shared.args.sdp_attention)):
        llama_attn_hijack.hijack_llama_attention()
    # Loading the tokenizer
    if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
    elif type(model) is transformers.LlamaForCausalLM:
        tokenizer = None

        # Try to load a universal LLaMA tokenizer
        if shared.model_type not in ['llava', 'oasst']:
            for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
                if p.exists():
                    logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
                    tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
                    break

        # Otherwise, load it from the model folder and hope that these
        # are not outdated tokenizer files.
        if tokenizer is None:
            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
            try:
                tokenizer.eos_token_id = 2
                tokenizer.bos_token_id = 1
                tokenizer.pad_token_id = 0
            except:
                pass
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)

    logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer


def clear_torch_cache():
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()


def unload_model():
    shared.model = shared.tokenizer = None
    clear_torch_cache()


def reload_model():
    unload_model()
    shared.model, shared.tokenizer = load_model(shared.model_name)


def load_soft_prompt(name):
    if name == 'None':
        shared.soft_prompt = False
        shared.soft_prompt_tensor = None
    else:
        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
            zf.extract('tensor.npy')
            zf.extract('meta.json')
            j = json.loads(open('meta.json', 'r').read())
            logging.info(f"\nLoading the softprompt \"{name}\".")
            for field in j:
                if field != 'name':
                    if type(j[field]) is list:
                        logging.info(f"{field}: {', '.join(j[field])}")
                    else:
                        logging.info(f"{field}: {j[field]}")

            tensor = np.load('tensor.npy')
            Path('tensor.npy').unlink()
            Path('meta.json').unlink()

        tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
        # Add a batch dimension: (num_tokens, hidden_size) -> (1, num_tokens, hidden_size)
        tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
        shared.soft_prompt = True
        shared.soft_prompt_tensor = tensor

    return name
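

# A minimal usage sketch (an illustration only, assuming the web UI has already
# populated shared.args and shared.model_name before these functions are called;
# reload_model() above follows the same pattern):
#
#     shared.model, shared.tokenizer = load_model(shared.model_name)
#     ...generate text with shared.model / shared.tokenizer...
#     unload_model()  # free VRAM before switching to another model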