diff --git a/modules/LoRA.py b/modules/LoRA.py
index 0cb1671e..15132f4e 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -12,7 +12,7 @@ from modules.models import reload_model
 def add_lora_to_model(lora_names):
     if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
         add_lora_autogptq(lora_names)
-    elif shared.model.__class__.__name__ == 'Exllamav2HF' or shared.args.loader == 'ExLlamav2_HF':
+    elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader in ['ExLlamav2', 'ExLlamav2_HF']:
         add_lora_exllamav2(lora_names)
     else:
         add_lora_transformers(lora_names)
@@ -39,7 +39,11 @@ def add_lora_exllamav2(lora_names):
         shared.model.loras = []
         for lora_name in lora_names:
             lora_path = get_lora_path(lora_name)
-            lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
+            if shared.model.__class__.__name__ == 'Exllamav2Model':
+                lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
+            else:
+                lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
+
             shared.model.loras.append(lora)
 
         shared.lora_names = lora_names
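The branch above exists because the two ExLlamaV2 wrappers expose the raw `ExLlamaV2` instance under different attribute names: `Exllamav2Model` keeps it at `.model`, while the HF wrapper keeps it at `.ex_model`. A minimal sketch of that lookup follows; `get_raw_exllamav2` is a hypothetical helper used only for illustration, not something this patch adds:

```python
# Hypothetical helper (not part of this patch) expressing the same branch as above:
def get_raw_exllamav2(wrapper):
    # Exllamav2Model (added in modules/exllamav2.py below) stores the raw model at .model;
    # Exllamav2HF (modules/exllamav2_hf.py) stores it at .ex_model.
    if wrapper.__class__.__name__ == 'Exllamav2Model':
        return wrapper.model
    return wrapper.ex_model
```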
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
new file mode 100644
index 00000000..551ed498
--- /dev/null
+++ b/modules/exllamav2.py
@@ -0,0 +1,149 @@
+import traceback
+from pathlib import Path
+
+import torch
+from exllamav2 import (
+    ExLlamaV2,
+    ExLlamaV2Cache,
+    ExLlamaV2Cache_8bit,
+    ExLlamaV2Config,
+    ExLlamaV2Tokenizer
+)
+from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
+
+from modules import shared
+from modules.logging_colors import logger
+from modules.text_generation import get_max_prompt_length
+
+try:
+    import flash_attn
+except ModuleNotFoundError:
+    logger.warning(
+        'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
+        'to be a lot higher than it could be.\n'
+        'Try installing flash-attention following the instructions here: '
+        'https://github.com/Dao-AILab/flash-attention#installation-and-features'
+    )
+    pass
+except Exception:
+    logger.warning('Failed to load flash-attention due to the following error:\n')
+    traceback.print_exc()
+
+
+class Exllamav2Model:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path_to_model):
+
+        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
+
+        config = ExLlamaV2Config()
+        config.model_dir = str(path_to_model)
+        config.prepare()
+
+        config.max_seq_len = shared.args.max_seq_len
+        config.scale_pos_emb = shared.args.compress_pos_emb
+        config.scale_alpha_value = shared.args.alpha_value
+        config.no_flash_attn = shared.args.no_flash_attn
+        config.num_experts_per_token = int(shared.args.num_experts_per_token)
+
+        model = ExLlamaV2(config)
+
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+        model.load(split)
+
+        tokenizer = ExLlamaV2Tokenizer(config)
+        if shared.args.cache_8bit:
+            cache = ExLlamaV2Cache_8bit(model)
+        else:
+            cache = ExLlamaV2Cache(model)
+
+        generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
+
+        result = self()
+        result.model = model
+        result.cache = cache
+        result.tokenizer = tokenizer
+        result.generator = generator
+        result.loras = None
+        return result, result
+
+    def encode(self, string, **kwargs):
+        return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
+
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, list):
+            ids = torch.tensor([ids])
+        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+            ids = ids.view(1, -1)
+
+        return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
+
+    def get_logits(self, token_ids, **kwargs):
+        self.cache.current_seq_len = 0
+        if token_ids.shape[-1] > 1:
+            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
+
+        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
+
+    def generate_with_streaming(self, prompt, state):
+        settings = ExLlamaV2Sampler.Settings()
+
+        settings.token_repetition_penalty = state['repetition_penalty']
+        settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']
+
+        settings.token_frequency_penalty = state['frequency_penalty']
+        settings.token_presence_penalty = state['presence_penalty']
+
+        settings.temperature = state['temperature']
+        settings.top_k = state['top_k']
+        settings.top_p = state['top_p']
+        settings.top_a = state['top_a']
+        settings.min_p = state['min_p']
+        settings.tfs = state['tfs']
+        settings.typical = state['typical_p']
+
+        settings.temperature_last = state['temperature_last']
+
+        settings.mirostat = state['mirostat_mode'] == 2
+        settings.mirostat_tau = state['mirostat_tau']
+        settings.mirostat_eta = state['mirostat_eta']
+
+        if state['ban_eos_token']:
+            settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+
+        if state['custom_token_bans']:
+            to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
+            if len(to_ban) > 0:
+                settings.disallow_tokens(self.tokenizer, to_ban)
+
+        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
+        ids = ids[:, -get_max_prompt_length(state):]
+
+        if state['auto_max_new_tokens']:
+            max_new_tokens = state['truncation_length'] - ids.shape[-1]
+        else:
+            max_new_tokens = state['max_new_tokens']
+
+        self.generator.begin_stream(ids, settings, loras=self.loras)
+
+        decoded_text = ''
+        for i in range(max_new_tokens):
+            chunk, eos, _ = self.generator.stream()
+            if eos or shared.stop_everything:
+                break
+
+            decoded_text += chunk
+            yield decoded_text
+
+    def generate(self, prompt, state):
+        output = ''
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
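For orientation, here is a rough usage sketch of the new wrapper. The model folder name and sampler values are illustrative; in the app, `state` comes from the UI and contains the keys registered for this loader in `modules/loaders.py` below, and loading normally goes through `load_model()` rather than calling the class directly:

```python
# Rough usage sketch (illustrative values, not part of this patch):
from modules.exllamav2 import Exllamav2Model

# from_pretrained() returns the same object twice, so shared.model and
# shared.tokenizer end up pointing at one Exllamav2Model instance.
model, tokenizer = Exllamav2Model.from_pretrained('MyModel-4.0bpw-exl2')

state = {
    'temperature': 0.7, 'temperature_last': False, 'top_p': 0.9, 'min_p': 0.0,
    'top_k': 20, 'typical_p': 1.0, 'tfs': 1.0, 'top_a': 0.0,
    'repetition_penalty': 1.15, 'repetition_penalty_range': 1024,
    'presence_penalty': 0.0, 'frequency_penalty': 0.0,
    'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1,
    'ban_eos_token': False, 'custom_token_bans': '', 'add_bos_token': True,
    'auto_max_new_tokens': False, 'max_new_tokens': 200, 'truncation_length': 4096,
}

# generate_with_streaming() yields the cumulative decoded text after each new token.
for text in model.generate_with_streaming('Once upon a time', state):
    print(text)
```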
diff --git a/modules/loaders.py b/modules/loaders.py
index 687a9e92..26b7c5e2 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -83,6 +83,16 @@ loaders_and_params = OrderedDict({
         'trust_remote_code',
         'no_use_fast',
     ],
+    'ExLlamav2': [
+        'gpu_split',
+        'max_seq_len',
+        'no_flash_attn',
+        'num_experts_per_token',
+        'cache_8bit',
+        'alpha_value',
+        'compress_pos_emb',
+        'exllamav2_info',
+    ],
     'AutoGPTQ': [
         'triton',
         'no_inject_fused_attention',
@@ -197,6 +207,29 @@ loaders_samplers = {
     'AutoAWQ': transformers_samplers(),
     'QuIP#': transformers_samplers(),
     'HQQ': transformers_samplers(),
+    'ExLlamav2': {
+        'temperature',
+        'temperature_last',
+        'top_p',
+        'min_p',
+        'top_k',
+        'typical_p',
+        'tfs',
+        'top_a',
+        'repetition_penalty',
+        'presence_penalty',
+        'frequency_penalty',
+        'repetition_penalty_range',
+        'seed',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'ban_eos_token',
+        'add_bos_token',
+        'custom_token_bans',
+        'skip_special_tokens',
+        'auto_max_new_tokens',
+    },
     'ExLlamav2_HF': {
         'temperature',
         'temperature_last',
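These two entries are what make the new loader show the right UI elements and advertise which sampling parameters it honors. An illustrative check of the registry (not part of the patch):

```python
# Illustrative check of the new registry entries (not part of this patch):
from modules import loaders

print(loaders.loaders_and_params['ExLlamav2'])                          # gradio elements shown for this loader
print('auto_max_new_tokens' in loaders.loaders_samplers['ExLlamav2'])   # True
print('exllamav2_info' in loaders.loaders_and_params['ExLlamav2'])      # True, the note added in ui_model_menu.py
```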
diff --git a/modules/logits.py b/modules/logits.py
index c2cbd92e..c630be88 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -13,10 +13,11 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
         logger.error("No model is loaded! Select one in the Model tab.")
         return 'Error: No model is loaded! Select one in the Model tab.', previous
 
+    is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
     is_non_hf_llamacpp = shared.model.__class__.__name__ == 'LlamaCppModel'
 
     if use_samplers:
-        if is_non_hf_llamacpp:
+        if any([is_non_hf_exllamav2, is_non_hf_llamacpp]):
             logger.error("Sampler hijacking is not supported non-Huggingface loaders.")
             # sampling is all done in c for exllama, so it is really hard to hijack
             # it should be possible to hijack llamacpp sampler by hijacking all their sampling methods,
@@ -30,7 +31,13 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
 
         scores = sampler_hijack.global_scores[-1]
     else:
-        if is_non_hf_llamacpp:
+        if is_non_hf_exllamav2:
+            if is_torch_xpu_available():
+                tokens = shared.tokenizer.encode(prompt).to("xpu:0")
+            else:
+                tokens = shared.tokenizer.encode(prompt).cuda()
+            scores = shared.model.get_logits(tokens)[-1][-1]
+        elif is_non_hf_llamacpp:
             tokens = shared.tokenizer.encode(prompt)
             scores = shared.model.get_logits(tokens)[-1][-1]
         else:
@@ -38,7 +45,6 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
             tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
         else:
             tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
-
         output = shared.model(input_ids=tokens)
         scores = output['logits'][-1][-1]
 
diff --git a/modules/models.py b/modules/models.py
index ab0e762c..6c38c3c7 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -65,6 +65,7 @@ def load_model(model_name, loader=None):
         'GPTQ-for-LLaMa': GPTQ_loader,
         'llama.cpp': llamacpp_loader,
         'llamacpp_HF': llamacpp_HF_loader,
+        'ExLlamav2': ExLlamav2_loader,
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ctransformers': ctransformers_loader,
         'AutoAWQ': AutoAWQ_loader,
@@ -375,6 +376,13 @@ def AutoGPTQ_loader(model_name):
     return modules.AutoGPTQ_loader.load_quantized(model_name)
 
 
+def ExLlamav2_loader(model_name):
+    from modules.exllamav2 import Exllamav2Model
+
+    model, tokenizer = Exllamav2Model.from_pretrained(model_name)
+    return model, tokenizer
+
+
 def ExLlamav2_HF_loader(model_name):
     from modules.exllamav2_hf import Exllamav2HF
 
diff --git a/modules/models_settings.py b/modules/models_settings.py
index 3e1649aa..9acc7efa 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -141,8 +141,6 @@ def get_model_metadata(model):
         if re.match(pat.lower(), model.lower()):
             for k in settings[pat]:
                 model_settings[k] = settings[pat][k]
-                if k == 'loader' and settings[pat][k] == 'ExLlamav2':
-                    model_settings[k] = 'ExLlamav2_HF'
 
     return model_settings
 
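With `'ExLlamav2': ExLlamav2_loader` registered in `load_func_map` and the forced remap to `ExLlamav2_HF` removed from `models_settings.py`, a model configured with this loader (or started with `--loader ExLlamav2`) now reaches the new code path. A hedged sketch of that flow, with an illustrative model folder name:

```python
# Hedged sketch of the resulting flow (model folder name is illustrative):
from modules import shared
from modules.models import load_model

shared.model, shared.tokenizer = load_model('MyModel-4.0bpw-exl2', loader='ExLlamav2')
print(shared.model.__class__.__name__)   # 'Exllamav2Model'
```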
diff --git a/modules/shared.py b/modules/shared.py
index eea3d27f..78966617 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -89,7 +89,7 @@ group.add_argument('--chat-buttons', action='store_true', help='Show buttons on
 
 # Model loader
 group = parser.add_argument_group('Model loader')
-group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ctransformers, QuIP#.')
+group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ctransformers, QuIP#.')
 
 # Transformers/Accelerate
 group = parser.add_argument_group('Transformers/Accelerate')
@@ -132,11 +132,11 @@ group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload
 group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
 group.add_argument('--row_split', action='store_true', help='Split multi-gpu by row instead of layer. Faster on some cards.')
 
-# ExLlamaV2
-group = parser.add_argument_group('ExLlamaV2')
+# ExLlama
+group = parser.add_argument_group('ExLlama')
 group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
 group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
-group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
+group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
 group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
 group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
@@ -250,7 +250,11 @@ def fix_loader_name(name):
         return 'AutoGPTQ'
     elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
         return 'GPTQ-for-LLaMa'
-    elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2', 'exllama', 'ex-llama', 'ex_llama', 'exlama', 'exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
+    elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
+        return 'ExLlama'
+    elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']:
+        return 'ExLlamav2'
+    elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
         return 'ExLlamav2_HF'
     elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']:
         return 'ctransformers'
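The `fix_loader_name` split means the various `exllama*` spellings no longer all collapse into `ExLlamav2_HF`; each family now resolves to its own canonical name. Illustrative calls (not part of the patch):

```python
# Illustrative calls against the updated fix_loader_name (not part of this patch):
from modules.shared import fix_loader_name

print(fix_loader_name('exllama2'))       # 'ExLlamav2'
print(fix_loader_name('exllamav2_hf'))   # 'ExLlamav2_HF'
print(fix_loader_name('ex-llama'))       # 'ExLlama'
```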
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 1808f8bf..1917a0c1 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -45,7 +45,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
             yield ''
             return
 
-        if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel']:
+        if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'CtransformersModel']:
             generate_func = generate_reply_custom
         else:
             generate_func = generate_reply_HF
@@ -121,11 +121,10 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if shared.tokenizer is None:
         raise ValueError('No tokenizer is loaded')
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel', 'Exllamav2Model']:
         input_ids = shared.tokenizer.encode(str(prompt))
-        # The step below is necessary for llama.cpp, but may not be
-        # necessary for future loaders.
-        input_ids = np.array(input_ids).reshape(1, len(input_ids))
+        if shared.model.__class__.__name__ not in ['Exllamav2Model']:
+            input_ids = np.array(input_ids).reshape(1, len(input_ids))
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
         if not add_bos_token:
@@ -136,7 +135,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu:
         return input_ids
     elif shared.args.deepspeed:
         return input_ids.to(device=local_rank)
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 87f15c1d..23679097 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -142,6 +142,7 @@ def create_ui():
                         shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel for GPTQ models.')
                         shared.gradio['disable_exllamav2'] = gr.Checkbox(label="disable_exllamav2", value=shared.args.disable_exllamav2, info='Disable ExLlamav2 kernel for GPTQ models.')
                         shared.gradio['gptq_for_llama_info'] = gr.Markdown('Legacy loader for compatibility with older GPUs. ExLlamav2_HF or AutoGPTQ are preferred for GPTQ models when supported.')
+                        shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
                         shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to download a tokenizer.\n\nOption 1 (recommended): place your .gguf in a subfolder of models/ along with these 4 files: special_tokens_map.json, tokenizer_config.json, tokenizer.json, tokenizer.model.\n\nOption 2: download `oobabooga/llama-tokenizer` under "Download model or LoRA". That\'s a default Llama tokenizer that will work for some (but not all) models.')
 
                     with gr.Column():
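One note on the `encode()` change above: `ExLlamaV2Tokenizer.encode()` already returns a `[1, seq_len]` torch tensor (the new `Exllamav2Model` code indexes it as `ids[:, ...]`), unlike the flat lists returned by the llama.cpp and ctransformers tokenizers, so the numpy reshape is skipped for this loader while the same truncation slice still applies. A small sketch of that shape contract, assuming an `Exllamav2Model` is the loaded model:

```python
# Sketch of the shape contract relied on above (assumes an Exllamav2Model is loaded):
from modules import shared

ids = shared.tokenizer.encode('Hello there')   # torch.Tensor of shape (1, n), already 2-D
ids = ids[:, -4096:]                           # same slicing encode() applies for truncation
print(ids.shape)
```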