Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)

This commit is contained in:
oobabooga 2023-08-06 17:22:48 -03:00 committed by GitHub
parent 5134878344
commit 0af10ab49b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 131 additions and 42 deletions

View File

@ -63,6 +63,8 @@ async def run(user_input, history):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,

View File

@ -57,6 +57,8 @@ def run(user_input, history):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,

View File

@ -45,6 +45,8 @@ async def run(context):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,

View File

@ -37,6 +37,8 @@ def run(prompt):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,

View File

@ -43,6 +43,8 @@ def build_parameters(body, chat=False):
'mirostat_mode': int(body.get('mirostat_mode', 0)),
'mirostat_tau': float(body.get('mirostat_tau', 5)),
'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
'guidance_scale': float(body.get('guidance_scale', 1)),
'negative_prompt': str(body.get('negative_prompt', '')),
'seed': int(body.get('seed', -1)),
'add_bos_token': bool(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),

View File

@ -33,6 +33,8 @@ default_req_params = {
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'ban_eos_token': False,
'skip_special_tokens': True,
'custom_stopping_strings': '',

View File

@ -1,9 +1,11 @@
from pathlib import Path
import torch.nn.functional as F
from torch import version as torch_version
from modules import shared
from modules.logging_colors import logger
from modules.models import clear_torch_cache
from modules.text_generation import get_max_prompt_length
try:
@ -78,6 +80,21 @@ class ExllamaModel:
return result, result
def generate_with_streaming(self, prompt, state):
# The cache batch size must be 2 for CFG and 1 otherwise
if state['guidance_scale'] == 1:
if self.cache.batch_size == 2:
del self.cache
clear_torch_cache()
self.cache = ExLlamaCache(self.model)
self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
else:
if self.cache.batch_size == 1:
del self.cache
clear_torch_cache()
self.cache = ExLlamaCache(self.model, batch_size=2)
self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
self.generator.settings.temperature = state['temperature']
self.generator.settings.top_p = state['top_p']
self.generator.settings.top_k = state['top_k']
@ -89,6 +106,8 @@ class ExllamaModel:
else:
self.generator.disallow_tokens(None)
# Case 1: no CFG
if state['guidance_scale'] == 1:
self.generator.end_beam_search()
# Tokenizing the input
@ -102,6 +121,7 @@ class ExllamaModel:
self.generator.gen_begin_reuse(ids)
initial_len = self.generator.sequence[0].shape[0]
has_leading_space = False
for i in range(max_new_tokens):
token = self.generator.gen_single_token()
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):
@ -115,6 +135,43 @@ class ExllamaModel:
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
break
# Case 2: CFG
else:
alpha = state['guidance_scale']
prompts = [prompt, state['negative_prompt'] or '']
ids, mask = self.tokenizer.encode(prompts, return_mask=True)
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
else:
max_new_tokens = state['max_new_tokens']
self.generator.gen_begin(ids, mask=mask)
initial_len = self.generator.sequence[0].shape[0]
has_leading_space = False
for i in range(max_new_tokens):
logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask)
self.generator.apply_rep_penalty(logits)
logits = F.log_softmax(logits, dim=-1)
logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1]
token, _ = self.generator.sample_current(logits_mixed)
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):
has_leading_space = True
decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
if has_leading_space:
decoded_text = ' ' + decoded_text
yield decoded_text
if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
break
batch_token = token.repeat(2, 1)
self.generator.gen_accept_token(batch_token)
def generate(self, prompt, state):
output = ''
for output in self.generate_with_streaming(prompt, state):

View File

@ -47,12 +47,11 @@ class ExllamaHF(PreTrainedModel):
return torch.device(0)
def __call__(self, *args, **kwargs):
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
assert len(args) == 0, 'no *args should be passed to forward'
input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
use_cache = kwargs.get('use_cache', True)
labels = kwargs.get('labels', None)
seq = kwargs['input_ids'][0].tolist()
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
cache = kwargs.get('past_key_values', None)
seq = input_ids[0].tolist()
if labels is None:
if cache is None:
@ -60,7 +59,7 @@ class ExllamaHF(PreTrainedModel):
cache = self.ex_cache
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(input_ids.device)
else:
if cache is None:
self.ex_cache.current_seq_len = 0

View File

@ -49,12 +49,11 @@ class LlamacppHF(PreTrainedModel):
return torch.device(0)
def __call__(self, *args, **kwargs):
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
assert len(args) == 0, 'no *args should be passed to forward'
input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
use_cache = kwargs.get('use_cache', True)
labels = kwargs.get('labels', None)
seq = kwargs['input_ids'][0].tolist()
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
cache = kwargs.get('past_key_values', None)
seq = input_ids[0].tolist()
# Make the forward call
seq_tensor = torch.tensor(seq)
@ -70,7 +69,7 @@ class LlamacppHF(PreTrainedModel):
self.model.reset()
self.model.eval(seq)
logits = torch.tensor(self.model.eval_logits)
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device)
self.cache = seq_tensor

View File

@ -115,6 +115,8 @@ loaders_samplers = {
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
@ -152,6 +154,8 @@ loaders_samplers = {
'repetition_penalty',
'repetition_penalty_range',
'seed',
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'auto_max_new_tokens',
},
@ -178,6 +182,8 @@ loaders_samplers = {
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
@ -206,6 +212,8 @@ loaders_samplers = {
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',

View File

@ -9,6 +9,7 @@ def default_preset():
'do_sample': True,
'temperature': 1,
'top_p': 1,
'top_k': 0,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
@ -17,19 +18,23 @@ def default_preset():
'repetition_penalty': 1,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1,
'top_k': 0,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
'min_length': 0,
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False,
}
def presets_params():
return [k for k in default_preset()]
def load_preset(name):
generate_params = default_preset()
if name not in ['None', None, '']:
@ -51,12 +56,12 @@ def load_preset_memoized(name):
def load_preset_for_ui(name, state):
generate_params = load_preset(name)
state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
return state, *[generate_params[k] for k in presets_params()]
def generate_preset_yaml(state):
defaults = default_preset()
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
data = {k: state[k] for k in presets_params()}
# Remove entries that are identical to the defaults
for k in list(data.keys()):

View File

@ -42,6 +42,7 @@ settings = {
'max_new_tokens_max': 4096,
'auto_max_new_tokens': False,
'seed': -1,
'negative_prompt': '',
'character': 'None',
'name1': 'You',
'name2': 'Assistant',

View File

@ -226,9 +226,12 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
generate_params[k] = state[k]
if state['negative_prompt'] != '':
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
for k in ['epsilon_cutoff', 'eta_cutoff']:
if state[k] > 0:
generate_params[k] = state[k] * 1e-4

View File

@ -100,6 +100,8 @@ def list_interface_input_elements():
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'negative_prompt',
'guidance_scale',
'add_bos_token',
'ban_eos_token',
'truncation_length',

View File

@ -15,10 +15,10 @@ safetensors==0.3.1
scipy
sentencepiece
tensorboard
transformers==4.31.*
tqdm
wandb
git+https://github.com/huggingface/peft@96c0277a1b9a381b10ab34dbf84917f9b3b992e6
git+https://github.com/huggingface/transformers@d533465150532b0c5de167b574e59f64c68b1154
bitsandbytes==0.41.1; platform_system != "Windows"
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.3.0/auto_gptq-0.3.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"

View File

@ -229,7 +229,7 @@ def create_model_menus():
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=0, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb)
shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=32, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Scaling is not identical to embedding compression. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)
@ -408,6 +408,8 @@ def create_settings_menus(default_preset):
with gr.Box():
with gr.Row():
with gr.Column():
shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.')
shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt')
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
@ -433,7 +435,7 @@ def create_settings_menus(default_preset):
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
filter_by_loader.change(loaders.blacklist_samplers, filter_by_loader, gradio(loaders.list_all_samplers()), show_progress=False)
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a'))
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
def create_file_saving_menus():

View File

@ -5,6 +5,7 @@ max_new_tokens_min: 1
max_new_tokens_max: 4096
auto_max_new_tokens: false
seed: -1
negative_prompt: ''
character: None
name1: You
name2: Assistant