From d6934bc7bc79d7e30629414d1f46faae404bbff9 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 24 Aug 2023 16:27:36 -0300 Subject: [PATCH] Implement CFG for ExLlama_HF (#3666) --- README.md | 1 + modules/exllama_hf.py | 57 ++++++++++++++++++++------- modules/llamacpp_hf.py | 80 +++++++++++++++++++++++++++++++++----- modules/loaders.py | 3 ++ modules/models_settings.py | 4 +- modules/shared.py | 1 + modules/ui.py | 1 + modules/ui_model_menu.py | 1 + 8 files changed, 122 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 9e63163e..3c58d0ec 100644 --- a/README.md +++ b/README.md @@ -304,6 +304,7 @@ Optionally, you can use the following command-line flags: |------------------|-------------| |`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. `20,7,7` | |`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. | +|`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. | #### GPTQ-for-LLaMa diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py index ebafb4f7..129ee52e 100644 --- a/modules/exllama_hf.py +++ b/modules/exllama_hf.py @@ -29,10 +29,16 @@ class ExllamaHF(PreTrainedModel): super().__init__(PretrainedConfig()) self.ex_config = config self.ex_model = ExLlama(self.ex_config) - self.ex_cache = ExLlamaCache(self.ex_model) self.generation_config = GenerationConfig() self.lora = None + self.ex_cache = ExLlamaCache(self.ex_model) + self.past_seq = None + + if shared.args.cfg_cache: + self.ex_cache_negative = ExLlamaCache(self.ex_model) + self.past_seq_negative = None + def _validate_model_class(self): pass @@ -47,25 +53,46 @@ class ExllamaHF(PreTrainedModel): return torch.device(0) def __call__(self, *args, **kwargs): - input_ids = args[0] if len(args) > 0 else kwargs['input_ids'] use_cache = kwargs.get('use_cache', True) labels = kwargs.get('labels', None) - cache = kwargs.get('past_key_values', None) - seq = input_ids[0].tolist() + past_key_values = kwargs.get('past_key_values', None) - if labels is None: - if cache is None: - self.ex_cache.current_seq_len = 0 - cache = self.ex_cache - self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora) + if len(args) > 0: + if not shared.args.cfg_cache: + logger.error("Please enable the cfg-cache option to use CFG with ExLlama_HF.") + return - logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(input_ids.device) + input_ids = args[0] + is_negative = True + past_seq = self.past_seq_negative + ex_cache = self.ex_cache_negative else: - if cache is None: - self.ex_cache.current_seq_len = 0 - cache = self.ex_cache + input_ids = kwargs['input_ids'] + is_negative = False + past_seq = self.past_seq + ex_cache = self.ex_cache - logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), cache, last_id_only=False, lora=self.lora) + seq = input_ids[0].tolist() + if is_negative and past_key_values is not None: + seq = past_key_values + seq + + seq_tensor = torch.tensor(seq) + + # Make the forward call + if labels is None: + if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]): + ex_cache.current_seq_len = 0 + self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True, lora=self.lora) + + logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache, lora=self.lora).to(input_ids.device) + else: + ex_cache.current_seq_len = 0 + logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False, lora=self.lora) + + if is_negative: + self.past_seq_negative = seq_tensor + else: + self.past_seq = seq_tensor loss = None if labels is not None: @@ -80,7 +107,7 @@ class ExllamaHF(PreTrainedModel): shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) - return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None, loss=loss) + return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 10c30112..3c7314d1 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -33,7 +33,22 @@ class LlamacppHF(PreTrainedModel): super().__init__(PretrainedConfig()) self.model = model self.generation_config = GenerationConfig() - self.cache = None + + self.past_seq = None + self.llamacpp_cache = { + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids, + 'scores': self.model.scores + } + + if shared.args.cfg_cache: + logger.warning('CFG is currently bugged and not functional for llamacpp_HF. Contributions are welcome.') + self.past_seq_negative = None + self.llamacpp_cache_negative = { + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids.copy(), + 'scores': self.model.scores.copy() + } def _validate_model_class(self): pass @@ -44,36 +59,83 @@ class LlamacppHF(PreTrainedModel): def prepare_inputs_for_generation(self, input_ids, **kwargs): return {'input_ids': input_ids, **kwargs} + def save_cache(self): + self.llamacpp_cache.update({ + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids, + 'scores': self.model.scores + }) + + def save_negative_cache(self): + self.llamacpp_cache_negative.update({ + 'n_tokens': self.model.n_tokens, + 'input_ids': self.model.input_ids, + 'scores': self.model.scores + }) + + def load_cache(self): + self.model.n_tokens = self.llamacpp_cache['n_tokens'] + self.model.input_ids = self.llamacpp_cache['input_ids'] + self.model.scores = self.llamacpp_cache['scores'] + + def load_negative_cache(self): + self.model.n_tokens = self.llamacpp_cache_negative['n_tokens'] + self.model.input_ids = self.llamacpp_cache_negative['input_ids'] + self.model.scores = self.llamacpp_cache_negative['scores'] + @property def device(self) -> torch.device: return torch.device(0) def __call__(self, *args, **kwargs): - input_ids = args[0] if len(args) > 0 else kwargs['input_ids'] use_cache = kwargs.get('use_cache', True) labels = kwargs.get('labels', None) - cache = kwargs.get('past_key_values', None) + past_key_values = kwargs.get('past_key_values', None) + + if len(args) > 0: + if not shared.args.cfg_cache: + logger.error("Please enable the cfg-cache option to use CFG with llamacpp_HF.") + logger.warning('CFG is currently bugged and not functional for llamacpp_HF. Contributions are welcome.') + return + + input_ids = args[0] + is_negative = True + past_seq = self.past_seq_negative + self.load_negative_cache() + else: + input_ids = kwargs['input_ids'] + is_negative = False + past_seq = self.past_seq + self.load_cache() + seq = input_ids[0].tolist() + if is_negative and past_key_values is not None: + seq = past_key_values + seq + + seq_tensor = torch.tensor(seq) # Make the forward call - seq_tensor = torch.tensor(seq) if labels is None: - if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]): + if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]): self.model.reset() self.model.eval(seq) else: self.model.eval([seq[-1]]) - logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(kwargs['input_ids'].device) + logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device) else: self.model.reset() self.model.eval(seq) logits = torch.tensor(self.model.eval_logits) logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device) - self.cache = seq_tensor + if is_negative: + self.save_negative_cache() + self.past_seq_negative = seq_tensor + else: + self.save_cache() + self.past_seq = seq_tensor - # Based on transformers/models/llama/modeling_llama.py loss = None if labels is not None: # Shift so that tokens < n predict n @@ -87,7 +149,7 @@ class LlamacppHF(PreTrainedModel): shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) - return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None, loss=loss) + return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): diff --git a/modules/loaders.py b/modules/loaders.py index 472e8ddb..b8660a46 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -29,6 +29,7 @@ loaders_and_params = OrderedDict({ 'max_seq_len', 'alpha_value', 'compress_pos_emb', + 'cfg_cache', 'exllama_HF_info', ], 'ExLlama': [ @@ -157,6 +158,8 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'guidance_scale', + 'negative_prompt', 'ban_eos_token', 'add_bos_token', 'skip_special_tokens', diff --git a/modules/models_settings.py b/modules/models_settings.py index 06a41da4..5efde34b 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -91,8 +91,8 @@ def apply_model_settings_to_state(model, state): if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0: loader = 'AutoGPTQ' - # If the user is using an alternative GPTQ loader, let them keep using it - if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']): + # If the user is using an alternative loader for the same model type, let them keep using it + if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']) and not (loader == 'llama.cpp' and state['loader'] in ['llamacpp_HF', 'ctransformers']): state['loader'] = loader for k in model_settings: diff --git a/modules/shared.py b/modules/shared.py index 385b99da..c89c906b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -147,6 +147,7 @@ parser.add_argument('--disable_exllama', action='store_true', help='Disable ExLl # ExLlama parser.add_argument('--gpu-split', type=str, help="Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7") parser.add_argument('--max_seq_len', type=int, default=2048, help="Maximum sequence length.") +parser.add_argument('--cfg-cache', action='store_true', help="ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.") # DeepSpeed parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') diff --git a/modules/ui.py b/modules/ui.py index 15f24d85..f6e9ac10 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -63,6 +63,7 @@ def list_model_elements(): 'no_inject_fused_mlp', 'no_use_cuda_fp16', 'disable_exllama', + 'cfg_cache', 'threads', 'n_batch', 'no_mmap', diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index e217bee1..05fe3af7 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -111,6 +111,7 @@ def create_ui(): shared.gradio['low_vram'] = gr.Checkbox(label="low-vram", value=shared.args.low_vram) shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock) shared.gradio['mul_mat_q'] = gr.Checkbox(label="mul_mat_q", value=shared.args.mul_mat_q) + shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Create an additional cache for CFG negative prompts.') shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='Split the model across multiple GPUs, comma-separated list of proportions, e.g. 18,17') shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed) shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')