Make the bos token optional

This commit is contained in:
oobabooga 2023-04-10 16:44:22 -03:00
parent 4961f43702
commit bd04ff27ad
3 changed files with 12 additions and 4 deletions

View File

@ -35,6 +35,7 @@ settings = {
'greeting': 'Hello there!',
'end_of_turn': '',
'stop_at_newline': False,
'add_bos_token': True,
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
'chat_prompt_size_max': 2048,

View File

@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids))
@ -30,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
# This is a hack for making replies more creative.
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
# Llama adds this extra token when the first character is '\n', and this
# compromises the stopping criteria, so we just remove it
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
input_ids = input_ids[:, 1:]
@ -158,7 +164,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
input_ids = encode(question, generate_state['max_new_tokens'])
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
original_input_ids = input_ids
output = input_ids[0]

View File

@ -233,7 +233,7 @@ def create_model_menus():
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts', 'add_bos_token']:
generate_params[k] = shared.settings[k]
shared.gradio['generate_state'] = gr.State(generate_params)
@ -273,6 +273,7 @@ def create_settings_menus(default_preset):
with gr.Column():
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
with gr.Accordion('Soft prompt', open=False):
with gr.Row():
@ -610,7 +611,7 @@ def create_interface():
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]: