From 0af10ab49bfc1cab80d0126707321a58bd9e3485 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 6 Aug 2023 17:22:48 -0300
Subject: [PATCH] Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)

---
 api-examples/api-example-chat-stream.py |  2 +
 api-examples/api-example-chat.py        |  2 +
 api-examples/api-example-stream.py      |  2 +
 api-examples/api-example.py             |  2 +
 extensions/api/util.py                  |  2 +
 extensions/openai/defaults.py           |  2 +
 modules/exllama.py                      | 97 ++++++++++++++++++++-----
 modules/exllama_hf.py                   |  9 +--
 modules/llamacpp_hf.py                  |  9 +--
 modules/loaders.py                      |  8 ++
 modules/presets.py                      | 21 ++++--
 modules/shared.py                       |  1 +
 modules/text_generation.py              |  5 +-
 modules/ui.py                           |  2 +
 requirements.txt                        |  2 +-
 server.py                               |  6 +-
 settings-template.yaml                  |  1 +
 17 files changed, 131 insertions(+), 42 deletions(-)

diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py
index 2914d451..a774f907 100644
--- a/api-examples/api-example-chat-stream.py
+++ b/api-examples/api-example-chat-stream.py
@@ -63,6 +63,8 @@ async def run(user_input, history):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py
index e2797f1e..824bf3a0 100644
--- a/api-examples/api-example-chat.py
+++ b/api-examples/api-example-chat.py
@@ -57,6 +57,8 @@ def run(user_input, history):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py
index 175275f9..bf5eabac 100644
--- a/api-examples/api-example-stream.py
+++ b/api-examples/api-example-stream.py
@@ -45,6 +45,8 @@ async def run(context):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
diff --git a/api-examples/api-example.py b/api-examples/api-example.py
index 7f8bc1d2..16029807 100644
--- a/api-examples/api-example.py
+++ b/api-examples/api-example.py
@@ -37,6 +37,8 @@ def run(prompt):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
diff --git a/extensions/api/util.py b/extensions/api/util.py
index ef58a70f..2654d046 100644
--- a/extensions/api/util.py
+++ b/extensions/api/util.py
@@ -43,6 +43,8 @@ def build_parameters(body, chat=False):
         'mirostat_mode': int(body.get('mirostat_mode', 0)),
         'mirostat_tau': float(body.get('mirostat_tau', 5)),
         'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
+        'guidance_scale': float(body.get('guidance_scale', 1)),
+        'negative_prompt': str(body.get('negative_prompt', '')),
         'seed': int(body.get('seed', -1)),
         'add_bos_token': bool(body.get('add_bos_token', True)),
         'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py
index cb8308e7..ffef12d0 100644
--- a/extensions/openai/defaults.py
+++ b/extensions/openai/defaults.py
@@ -33,6 +33,8 @@ default_req_params = {
     'mirostat_mode': 0,
     'mirostat_tau': 5.0,
     'mirostat_eta': 0.1,
+    'guidance_scale': 1,
+    'negative_prompt': '',
     'ban_eos_token': False,
     'skip_special_tokens': True,
     'custom_stopping_strings': '',
diff --git a/modules/exllama.py b/modules/exllama.py
index 00b37b9c..dc632a25 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -1,9 +1,11 @@
 from pathlib import Path
 
+import torch.nn.functional as F
 from torch import version as torch_version
 
 from modules import shared
 from modules.logging_colors import logger
+from modules.models import clear_torch_cache
 from modules.text_generation import get_max_prompt_length
 
 try:
@@ -78,6 +80,21 @@ class ExllamaModel:
         return result, result
 
     def generate_with_streaming(self, prompt, state):
+
+        # The cache batch size must be 2 for CFG and 1 otherwise
+        if state['guidance_scale'] == 1:
+            if self.cache.batch_size == 2:
+                del self.cache
+                clear_torch_cache()
+                self.cache = ExLlamaCache(self.model)
+                self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+        else:
+            if self.cache.batch_size == 1:
+                del self.cache
+                clear_torch_cache()
+                self.cache = ExLlamaCache(self.model, batch_size=2)
+                self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+
         self.generator.settings.temperature = state['temperature']
         self.generator.settings.top_p = state['top_p']
         self.generator.settings.top_k = state['top_k']
@@ -89,31 +106,71 @@ class ExllamaModel:
         else:
             self.generator.disallow_tokens(None)
 
-        self.generator.end_beam_search()
+        # Case 1: no CFG
+        if state['guidance_scale'] == 1:
+            self.generator.end_beam_search()
 
-        # Tokenizing the input
-        ids = self.generator.tokenizer.encode(prompt)
-        ids = ids[:, -get_max_prompt_length(state):]
-        if state['auto_max_new_tokens']:
-            max_new_tokens = state['truncation_length'] - ids.shape[-1]
+            # Tokenizing the input
+            ids = self.generator.tokenizer.encode(prompt)
+            ids = ids[:, -get_max_prompt_length(state):]
+            if state['auto_max_new_tokens']:
+                max_new_tokens = state['truncation_length'] - ids.shape[-1]
+            else:
+                max_new_tokens = state['max_new_tokens']
+
+            self.generator.gen_begin_reuse(ids)
+            initial_len = self.generator.sequence[0].shape[0]
+            has_leading_space = False
+
+            for i in range(max_new_tokens):
+                token = self.generator.gen_single_token()
+                if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+                    has_leading_space = True
+
+                decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
+                if has_leading_space:
+                    decoded_text = ' ' + decoded_text
+
+                yield decoded_text
+                if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
+                    break
+
+        # Case 2: CFG
         else:
-            max_new_tokens = state['max_new_tokens']
+            alpha = state['guidance_scale']
+            prompts = [prompt, state['negative_prompt'] or '']
 
-        self.generator.gen_begin_reuse(ids)
-        initial_len = self.generator.sequence[0].shape[0]
-        has_leading_space = False
-        for i in range(max_new_tokens):
-            token = self.generator.gen_single_token()
-            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
-                has_leading_space = True
+            ids, mask = self.tokenizer.encode(prompts, return_mask=True)
+            if state['auto_max_new_tokens']:
+                max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
+            else:
+                max_new_tokens = state['max_new_tokens']
 
-            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
-            if has_leading_space:
-                decoded_text = ' ' + decoded_text
+            self.generator.gen_begin(ids, mask=mask)
+            initial_len = self.generator.sequence[0].shape[0]
+            has_leading_space = False
 
-            yield decoded_text
-            if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
-                break
+            for i in range(max_new_tokens):
+                logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask)
+                self.generator.apply_rep_penalty(logits)
+
+                logits = F.log_softmax(logits, dim=-1)
+                logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1]
+
+                token, _ = self.generator.sample_current(logits_mixed)
+                if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+                    has_leading_space = True
+
+                decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
+                if has_leading_space:
+                    decoded_text = ' ' + decoded_text
+
+                yield decoded_text
+                if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
+                    break
+
+                batch_token = token.repeat(2, 1)
+                self.generator.gen_accept_token(batch_token)
 
     def generate(self, prompt, state):
         output = ''
diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py
index fd775b4a..ebafb4f7 100644
--- a/modules/exllama_hf.py
+++ b/modules/exllama_hf.py
@@ -47,12 +47,11 @@ class ExllamaHF(PreTrainedModel):
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
-        # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
-        assert len(args) == 0, 'no *args should be passed to forward'
+        input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
         use_cache = kwargs.get('use_cache', True)
         labels = kwargs.get('labels', None)
-        seq = kwargs['input_ids'][0].tolist()
-        cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
+        cache = kwargs.get('past_key_values', None)
+        seq = input_ids[0].tolist()
 
         if labels is None:
             if cache is None:
@@ -60,7 +59,7 @@ class ExllamaHF(PreTrainedModel):
                 cache = self.ex_cache
                 self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
 
-            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
+            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(input_ids.device)
         else:
             if cache is None:
                 self.ex_cache.current_seq_len = 0
diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index e9f4ade6..df9e0b2e 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -49,12 +49,11 @@ class LlamacppHF(PreTrainedModel):
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
-        # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
-        assert len(args) == 0, 'no *args should be passed to forward'
+        input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
         use_cache = kwargs.get('use_cache', True)
         labels = kwargs.get('labels', None)
-        seq = kwargs['input_ids'][0].tolist()
-        cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
+        cache = kwargs.get('past_key_values', None)
+        seq = input_ids[0].tolist()
 
         # Make the forward call
         seq_tensor = torch.tensor(seq)
@@ -70,7 +69,7 @@ class LlamacppHF(PreTrainedModel):
             self.model.reset()
             self.model.eval(seq)
             logits = torch.tensor(self.model.eval_logits)
-            logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
+            logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device)
 
         self.cache = seq_tensor
 
diff --git a/modules/loaders.py b/modules/loaders.py
index aa1afcb8..519e47a7 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -115,6 +115,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
@@ -152,6 +154,8 @@ loaders_samplers = {
         'repetition_penalty',
         'repetition_penalty_range',
         'seed',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'auto_max_new_tokens',
     },
@@ -178,6 +182,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
@@ -206,6 +212,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
diff --git a/modules/presets.py b/modules/presets.py
index 072b15fd..32b7f71c 100644
--- a/modules/presets.py
+++ b/modules/presets.py
@@ -9,6 +9,7 @@ def default_preset():
         'do_sample': True,
         'temperature': 1,
         'top_p': 1,
+        'top_k': 0,
         'typical_p': 1,
         'epsilon_cutoff': 0,
         'eta_cutoff': 0,
@@ -17,19 +18,23 @@ def default_preset():
         'repetition_penalty': 1,
         'repetition_penalty_range': 0,
         'encoder_repetition_penalty': 1,
-        'top_k': 0,
-        'num_beams': 1,
-        'penalty_alpha': 0,
-        'min_length': 0,
-        'length_penalty': 1,
         'no_repeat_ngram_size': 0,
-        'early_stopping': False,
+        'min_length': 0,
+        'guidance_scale': 1,
         'mirostat_mode': 0,
         'mirostat_tau': 5.0,
         'mirostat_eta': 0.1,
+        'penalty_alpha': 0,
+        'num_beams': 1,
+        'length_penalty': 1,
+        'early_stopping': False,
     }
 
 
+def presets_params():
+    return [k for k in default_preset()]
+
+
 def load_preset(name):
     generate_params = default_preset()
     if name not in ['None', None, '']:
@@ -51,12 +56,12 @@ def load_preset_memoized(name):
 def load_preset_for_ui(name, state):
     generate_params = load_preset(name)
     state.update(generate_params)
-    return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
+    return state, *[generate_params[k] for k in presets_params()]
 
 
 def generate_preset_yaml(state):
     defaults = default_preset()
-    data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
+    data = {k: state[k] for k in presets_params()}
 
     # Remove entries that are identical to the defaults
     for k in list(data.keys()):
diff --git a/modules/shared.py b/modules/shared.py
index 51017a1b..be5be109 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -42,6 +42,7 @@ settings = {
     'max_new_tokens_max': 4096,
     'auto_max_new_tokens': False,
     'seed': -1,
+    'negative_prompt': '',
     'character': 'None',
     'name1': 'You',
     'name2': 'Assistant',
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7507a731..df9d708b 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -226,9 +226,12 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
 
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
         generate_params[k] = state[k]
 
+    if state['negative_prompt'] != '':
+        generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
+
     for k in ['epsilon_cutoff', 'eta_cutoff']:
         if state[k] > 0:
             generate_params[k] = state[k] * 1e-4
diff --git a/modules/ui.py b/modules/ui.py
index eed2ef66..8a7f9f47 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -100,6 +100,8 @@ def list_interface_input_elements():
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'negative_prompt',
+        'guidance_scale',
         'add_bos_token',
         'ban_eos_token',
         'truncation_length',
diff --git a/requirements.txt b/requirements.txt
index 5a46addd..9deadd48 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,10 +15,10 @@ safetensors==0.3.1
 scipy
 sentencepiece
 tensorboard
-transformers==4.31.*
 tqdm
 wandb
 git+https://github.com/huggingface/peft@96c0277a1b9a381b10ab34dbf84917f9b3b992e6
+git+https://github.com/huggingface/transformers@d533465150532b0c5de167b574e59f64c68b1154
 bitsandbytes==0.41.1; platform_system != "Windows"
 https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.3.0/auto_gptq-0.3.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
diff --git a/server.py b/server.py
index 0e1d199d..adff9669 100644
--- a/server.py
+++ b/server.py
@@ -229,7 +229,7 @@ def create_model_menus():
                             shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
                             shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
                             shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
-                            shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
+                            shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=0, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
                             shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb)
                             shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=32, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Scaling is not identical to embedding compression. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)
 
@@ -408,6 +408,8 @@ def create_settings_menus(default_preset):
                     with gr.Box():
                         with gr.Row():
                             with gr.Column():
+                                shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.')
+                                shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt')
                                 shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
                                 shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
                                 shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
@@ -433,7 +435,7 @@ def create_settings_menus(default_preset):
             shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
 
     filter_by_loader.change(loaders.blacklist_samplers, filter_by_loader, gradio(loaders.list_all_samplers()), show_progress=False)
-    shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a'))
+    shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
 
 
 def create_file_saving_menus():
diff --git a/settings-template.yaml b/settings-template.yaml
index 62e86371..a0c53b33 100644
--- a/settings-template.yaml
+++ b/settings-template.yaml
@@ -5,6 +5,7 @@ max_new_tokens_min: 1
 max_new_tokens_max: 4096
 auto_max_new_tokens: false
 seed: -1
+negative_prompt: ''
 character: None
 name1: You
 name2: Assistant
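
For reference, the CFG branch added to modules/exllama.py above mixes the logits of the
guided prompt (batch row 0) and the negative prompt (batch row 1) in log-probability
space. A minimal standalone sketch of that mixing rule follows; it is not part of the
patch, and the toy vocabulary size and the greedy argmax pick are illustrative
assumptions rather than what the generator actually does.

    import torch
    import torch.nn.functional as F

    def cfg_mix(positive_logits: torch.Tensor, negative_logits: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # Same rule as the CFG branch in modules/exllama.py: convert both rows to
        # log-probabilities, then weight them so the result is pushed away from
        # the negative prompt as guidance_scale grows past 1.
        positive = F.log_softmax(positive_logits, dim=-1)
        negative = F.log_softmax(negative_logits, dim=-1)
        return guidance_scale * positive + (1 - guidance_scale) * negative

    # Toy usage: row 0 holds the guided prompt's logits, row 1 the negative prompt's.
    logits = torch.randn(2, 32000)  # hypothetical vocabulary size
    mixed = cfg_mix(logits[0], logits[1], guidance_scale=1.5)
    next_token = torch.argmax(mixed)  # the patch samples via the generator instead of argmax

With guidance_scale == 1 the negative term drops out entirely, which is why the patch
keeps a plain single-batch generation path for that value and only switches the ExLlama
cache to batch size 2 when CFG is actually requested.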