From c2a309f56e0160b268f0be155fd3e168f1c4bb93 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 12 Sep 2023 14:33:07 -0300
Subject: [PATCH] Add ExLlamaV2 and ExLlamav2_HF loaders (#3881)

---
 models/config.yaml         |   2 +-
 modules/exllamav2.py       | 102 ++++++++++++++++++++++++++++++++
 modules/exllamav2_hf.py    | 119 +++++++++++++++++++++++++++++++++++++
 modules/loaders.py         |  45 ++++++++++++++
 modules/models.py          |  15 +++++
 modules/shared.py          |   4 ++
 modules/text_generation.py |   9 +--
 requirements.txt           |   2 +
 requirements_nocuda.txt    |   2 +
 9 files changed, 295 insertions(+), 5 deletions(-)
 create mode 100644 modules/exllamav2.py
 create mode 100644 modules/exllamav2_hf.py

diff --git a/models/config.yaml b/models/config.yaml
index a2eb0146..e69e2562 100644
--- a/models/config.yaml
+++ b/models/config.yaml
@@ -210,7 +210,7 @@ llama-65b-gptq-3bit:
   instruction_template: 'Alpaca'
 .*llama-(2|v2):
   truncation_length: 4096
-.*llama-(2|v2).*chat:
+.*llama(-?)(2|v2).*chat:
   instruction_template: 'Llama-v2'
 .*newhope:
   instruction_template: 'NewHope'
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
new file mode 100644
index 00000000..4f89e0e6
--- /dev/null
+++ b/modules/exllamav2.py
@@ -0,0 +1,102 @@
+import random
+from pathlib import Path
+
+import torch
+from exllamav2 import (
+    ExLlamaV2,
+    ExLlamaV2Cache,
+    ExLlamaV2Config,
+    ExLlamaV2Tokenizer
+)
+from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
+
+from modules import shared
+from modules.text_generation import get_max_prompt_length
+
+
+class Exllamav2Model:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path_to_model):
+
+        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
+
+        config = ExLlamaV2Config()
+        config.model_dir = path_to_model
+        config.prepare()
+
+        config.max_seq_len = shared.args.max_seq_len
+        model = ExLlamaV2(config)
+
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+        model.load(split)
+
+        tokenizer = ExLlamaV2Tokenizer(config)
+        cache = ExLlamaV2Cache(model)
+        generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
+
+        result = self()
+        result.model = model
+        result.cache = cache
+        result.tokenizer = tokenizer
+        result.generator = generator
+        return result, tokenizer
+
+    def generate_with_streaming(self, prompt, state):
+        settings = ExLlamaV2Sampler.Settings()
+        settings.temperature = state['temperature']
+        settings.top_k = state['top_k']
+        settings.top_p = state['top_p']
+        settings.token_repetition_penalty = state['repetition_penalty']
+        settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']
+        if state['ban_eos_token']:
+            settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+
+        ids = self.tokenizer.encode(prompt)
+        ids = ids[:, -get_max_prompt_length(state):]
+        initial_len = ids.shape[-1]
+
+        if state['auto_max_new_tokens']:
+            max_new_tokens = state['truncation_length'] - ids.shape[-1]
+        else:
+            max_new_tokens = state['max_new_tokens']
+
+        # _gen_begin_base
+        self.cache.current_seq_len = 0
+        self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+
+        has_leading_space = False
+        for i in range(max_new_tokens):
+            logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
+            token, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random())
+            ids = torch.cat([ids, token], dim=1)
+
+            if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+                has_leading_space = True
+
+            decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0]
+            if has_leading_space:
+                decoded_text = ' ' + decoded_text
+
+            yield decoded_text
+
+            if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
+                break
+
+    def generate(self, prompt, state):
+        output = ''
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
+    def encode(self, string, **kwargs):
+        return self.tokenizer.encode(string)
+
+    def decode(self, string, **kwargs):
+        return self.tokenizer.decode(string)[0]
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
new file mode 100644
index 00000000..2eb2d087
--- /dev/null
+++ b/modules/exllamav2_hf.py
@@ -0,0 +1,119 @@
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import torch
+from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
+from torch.nn import CrossEntropyLoss
+from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from modules import shared
+from modules.logging_colors import logger
+
+
+class Exllamav2HF(PreTrainedModel):
+    def __init__(self, config: ExLlamaV2Config):
+        super().__init__(PretrainedConfig())
+        self.ex_config = config
+        self.ex_model = ExLlamaV2(config)
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+        self.ex_model.load(split)
+
+        self.generation_config = GenerationConfig()
+
+        self.ex_cache = ExLlamaV2Cache(self.ex_model)
+        self.past_seq = None
+
+        if shared.args.cfg_cache:
+            self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
+            self.past_seq_negative = None
+
+    def _validate_model_class(self):
+        pass
+
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        pass
+
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {'input_ids': input_ids, **kwargs}
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device(0)
+
+    def __call__(self, *args, **kwargs):
+        use_cache = kwargs.get('use_cache', True)
+        labels = kwargs.get('labels', None)
+        past_key_values = kwargs.get('past_key_values', None)
+
+        if len(args) > 0:
+            if not shared.args.cfg_cache:
+                logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
+                return
+
+            input_ids = args[0]
+            is_negative = True
+            past_seq = self.past_seq_negative
+            ex_cache = self.ex_cache_negative
+        else:
+            input_ids = kwargs['input_ids']
+            is_negative = False
+            past_seq = self.past_seq
+            ex_cache = self.ex_cache
+
+        seq = input_ids[0].tolist()
+        if is_negative and past_key_values is not None:
+            seq = past_key_values + seq
+
+        seq_tensor = torch.tensor(seq)
+
+        # Make the forward call
+        if labels is None:
+            if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
+                ex_cache.current_seq_len = 0
+                self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True)
+
+            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache).to(input_ids.device)
+        else:
+            ex_cache.current_seq_len = 0
+            # logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False)
+            logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache)
+
+        if is_negative:
+            self.past_seq_negative = seq_tensor
+        else:
+            self.past_seq = seq_tensor
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, logits.shape[-1])
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
+        assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+
+        pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
+
+        config = ExLlamaV2Config()
+        config.model_dir = pretrained_model_name_or_path
+        config.prepare()
+        config.max_seq_len = shared.args.max_seq_len
+
+        return Exllamav2HF(config)
diff --git a/modules/loaders.py b/modules/loaders.py
index 03327fb4..28882a6a 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -42,6 +42,15 @@ loaders_and_params = OrderedDict({
         'compress_pos_emb',
         'exllama_info',
     ],
+    'ExLlamav2': [
+        'gpu_split',
+        'max_seq_len',
+    ],
+    'ExLlamav2_HF': [
+        'gpu_split',
+        'max_seq_len',
+        'cfg_cache',
+    ],
     'AutoGPTQ': [
         'triton',
         'no_inject_fused_attention',
@@ -180,6 +189,42 @@ loaders_samplers = {
         'ban_eos_token',
         'auto_max_new_tokens',
     },
+    'ExLlamav2': {
+        'temperature',
+        'top_p',
+        'top_k',
+        'repetition_penalty',
+        'repetition_penalty_range',
+        'seed',
+        'ban_eos_token',
+        'auto_max_new_tokens',
+    },
+    'ExLlamav2_HF': {
+        'temperature',
+        'top_p',
+        'top_k',
+        'typical_p',
+        'epsilon_cutoff',
+        'eta_cutoff',
+        'tfs',
+        'top_a',
+        'repetition_penalty',
+        'repetition_penalty_range',
+        'encoder_repetition_penalty',
+        'no_repeat_ngram_size',
+        'min_length',
+        'seed',
+        'do_sample',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
+        'ban_eos_token',
+        'add_bos_token',
+        'skip_special_tokens',
+        'auto_max_new_tokens',
+    },
     'AutoGPTQ': {
         'temperature',
         'top_p',
diff --git a/modules/models.py b/modules/models.py
index e336f155..39a63e1f 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -59,6 +59,8 @@ def load_model(model_name, loader=None):
         'RWKV': RWKV_loader,
         'ExLlama': ExLlama_loader,
         'ExLlama_HF': ExLlama_HF_loader,
+        'ExLlamav2': ExLlamav2_loader,
+        'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ctransformers': ctransformers_loader,
     }
 
@@ -329,6 +331,19 @@ def ExLlama_HF_loader(model_name):
     return ExllamaHF.from_pretrained(model_name)
 
 
+def ExLlamav2_loader(model_name):
+    from modules.exllamav2 import Exllamav2Model
+
+    model, tokenizer = Exllamav2Model.from_pretrained(model_name)
+    return model, tokenizer
+
+
+def ExLlamav2_HF_loader(model_name):
+    from modules.exllamav2_hf import Exllamav2HF
+
+    return Exllamav2HF.from_pretrained(model_name)
+
+
 def get_max_memory_dict():
     max_memory = {}
     if shared.args.gpu_memory:
diff --git a/modules/shared.py b/modules/shared.py
index 8c73e609..829d7c01 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -219,6 +219,10 @@ def fix_loader_name(name):
         return 'ExLlama'
     elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']:
         return 'ExLlama_HF'
+    elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2']:
+        return 'ExLlamav2'
+    elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf']:
+        return 'ExLlamav2_HF'
     elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']:
         return 'ctransformers'
 
diff --git a/modules/text_generation.py b/modules/text_generation.py
index ba3f7e69..67833d8c 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -42,7 +42,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         yield ''
         return
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']:
         generate_func = generate_reply_custom
     else:
         generate_func = generate_reply_HF
@@ -106,9 +106,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
 
 
 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']:
         input_ids = shared.tokenizer.encode(str(prompt))
-        input_ids = np.array(input_ids).reshape(1, len(input_ids))
+        if shared.model.__class__.__name__ not in ['Exllamav2Model']:
+            input_ids = np.array(input_ids).reshape(1, len(input_ids))
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
 
@@ -120,7 +121,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu:
         return input_ids
     elif shared.args.deepspeed:
         return input_ids.to(device=local_rank)
diff --git a/requirements.txt b/requirements.txt
index 5ad831c4..09dfccfb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,9 @@ accelerate==0.22.*
 colorama
 datasets
 einops
+exllamav2==0.0.0
 markdown
+ninja
 numpy==1.24
 optimum==1.12.0
 pandas
diff --git a/requirements_nocuda.txt b/requirements_nocuda.txt
index e44ddd30..891aa7e8 100644
--- a/requirements_nocuda.txt
+++ b/requirements_nocuda.txt
@@ -8,7 +8,9 @@ accelerate==0.22.*
 colorama
 datasets
 einops
+exllamav2==0.0.0
 markdown
+ninja
 numpy==1.24
 optimum==1.12.0
 pandas
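
Reviewer note, not part of the patch: a minimal sketch of how the new ExLlamav2 loader can be exercised from a Python shell inside the webui checkout, for testing outside the Gradio UI. The model folder name is hypothetical, an ExLlamaV2-compatible model is assumed to already sit under shared.args.model_dir, and the state dict contains only the keys that Exllamav2Model.generate_with_streaming() in modules/exllamav2.py actually reads.

# Run from the webui root so that modules.shared and modules.models import cleanly.
from modules import shared
from modules.models import load_model

# Hypothetical model folder under shared.args.model_dir (default: models/).
model, tokenizer = load_model('MyModel-exl2', loader='ExLlamav2')

# Sampling options normally supplied by the UI state dict.
state = {
    'temperature': 0.7,
    'top_k': 20,
    'top_p': 0.9,
    'repetition_penalty': 1.15,
    'repetition_penalty_range': 0,
    'ban_eos_token': False,
    'auto_max_new_tokens': False,
    'truncation_length': shared.args.max_seq_len,
    'max_new_tokens': 200,
}

# Streaming generation yields the decoded text so far on each step.
for partial_text in model.generate_with_streaming('Once upon a time,', state):
    print(partial_text)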