Add skip_special_tokens checkbox for Dolly model (#1218)

Authored by oobabooga on 2023-04-16 14:24:49 -03:00, committed by GitHub
parent a9c7ef4159
commit b937c9d8c2
9 changed files with 35 additions and 15 deletions
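The change turns special-token handling into a per-request setting instead of a hard-coded check on the model name: models such as Open Assistant, Galactica and Dolly emit markers like <|endoftext|> or, in Dolly's case, '### End' as special tokens, and stripping them during decoding hides exactly the text that formatting and stopping-string checks rely on. A minimal sketch of what the flag toggles at the tokenizer level, using the public 'gpt2' tokenizer from the transformers library (not part of this commit) purely for illustration:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
ids = tok('Hello world').input_ids + [tok.eos_token_id]   # append the <|endoftext|> special token

print(tok.decode(ids, skip_special_tokens=True))    # -> 'Hello world'
print(tok.decode(ids, skip_special_tokens=False))   # -> 'Hello world<|endoftext|>'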

View File

@@ -42,9 +42,10 @@ async def run(context):
'early_stopping': False,
'seed': -1,
'add_bos_token': True,
- 'truncation_length': 2048,
'custom_stopping_strings': [],
- 'ban_eos_token': False
+ 'truncation_length': 2048,
+ 'ban_eos_token': False,
+ 'skip_special_tokens': True,
}
payload = json.dumps([context, params])
session = random_hash()
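For reference, the streaming example above sends the prompt and the parameter dictionary as a JSON-encoded pair, so a client opts into (or out of) special-token skipping simply by adding the key to params. A trimmed-down sketch of that payload (the values are placeholders; only 'skip_special_tokens' comes from this diff):

import json

context = 'What is the capital of France?'
params = {'max_new_tokens': 200, 'do_sample': True, 'skip_special_tokens': True}   # abridged parameter dict
payload = json.dumps([context, params])   # same shape as the payload built in the script above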

View File

@@ -39,6 +39,7 @@ params = {
'custom_stopping_strings': [],
'truncation_length': 2048,
'ban_eos_token': False,
+ 'skip_special_tokens': True,
}
# Input prompt

View File

@@ -61,6 +61,7 @@ class Handler(BaseHTTPRequestHandler):
'custom_stopping_strings': body.get('custom_stopping_strings', []),
'truncation_length': int(body.get('truncation_length', 2048)),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
+ 'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
}
generator = generate_reply(
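A hedged client-side sketch for the handler above: the parameter keys mirror the body.get(...) calls in the diff, while the 'prompt' field and the host, port and path are assumptions about the API extension's defaults at the time, so adjust them to your setup.

import requests

resp = requests.post(
    'http://localhost:5000/api/v1/generate',   # assumed default endpoint of the API extension
    json={
        'prompt': '### Instruction:\nName three primary colors.\n\n### Response:\n',
        'max_new_tokens': 200,
        'ban_eos_token': False,
        'skip_special_tokens': False,   # keep Dolly-style markers such as '### End' in the reply
    },
)
print(resp.json())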

View File

@@ -4,6 +4,8 @@
groupsize: 'None'
pre_layer: 0
mode: 'cai-chat'
+ skip_special_tokens: true
+ custom_stopping_strings: ''
llama-[0-9]*b-4bit$:
wbits: 4
model_type: 'llama'
@@ -33,3 +35,10 @@ llama-[0-9]*b-4bit$:
instruction_template: 'Alpaca'
wbits: 4
groupsize: 128
+ .*(galactica|oasst):
+ skip_special_tokens: false
+ .*dolly-v[0-9]-[0-9]*b:
+ mode: 'instruct'
+ instruction_template: 'Alpaca'
+ skip_special_tokens: false
+ custom_stopping_strings: '"### End"'
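The two new YAML entries give Dolly and the oasst/galactica families skip_special_tokens: false by default, while everything else keeps the new global default of true. An illustrative sketch (a simplified stand-in, not the webui's actual config loader) of how such regex keys can be matched against a model folder name to collect per-model defaults:

import re

model_config = {
    r'.*(galactica|oasst)': {'skip_special_tokens': False},
    r'.*dolly-v[0-9]-[0-9]*b': {
        'mode': 'instruct',
        'instruction_template': 'Alpaca',
        'skip_special_tokens': False,
        'custom_stopping_strings': '"### End"',
    },
}

def settings_for(model_name):
    settings = {'skip_special_tokens': True}          # global default
    for pattern, overrides in model_config.items():
        if re.match(pattern.lower(), model_name.lower()):
            settings.update(overrides)
    return settings

print(settings_for('dolly-v2-12b')['skip_special_tokens'])   # -> False
print(settings_for('llama-7b')['skip_special_tokens'])       # -> True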

View File

@@ -41,6 +41,7 @@ settings = {
'stop_at_newline': False,
'add_bos_token': True,
'ban_eos_token': False,
+ 'skip_special_tokens': True,
'truncation_length': 2048,
'truncation_length_min': 0,
'truncation_length_max': 4096,

View File

@@ -57,14 +57,13 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
return input_ids.cuda()
- def decode(output_ids):
- # Open Assistant relies on special tokens like <|endoftext|>
- if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
- return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
- else:
+ def decode(output_ids, skip_special_tokens=True):
+ if skip_special_tokens:
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
return reply
+ else:
+ return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
def generate_softprompt_input_tensors(input_ids):
@@ -184,7 +183,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
output = input_ids[0]
if shared.args.verbose:
- print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
+ print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
@@ -231,11 +230,12 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
output = shared.model.generate(**generate_params)[0]
if cuda:
output = output.cuda()
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0])
- reply = decode(output[-new_tokens:])
+ reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
@@ -256,18 +256,20 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if not shared.is_chat():
yield formatted_outputs(original_question, shared.model_name)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0])
- reply = decode(output[-new_tokens:])
+ reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
if output[-1] in eos_token_ids:
break
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
@@ -276,18 +278,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(original_input_ids[0])
- reply = decode(output[-new_tokens:])
+ reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield formatted_outputs(reply, shared.model_name)
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
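Threading state['skip_special_tokens'] into decode() is what makes the Dolly preset above work: with skipping disabled, markers like '### End' survive decoding, so the custom stopping string configured in models/config.yaml has something to match. An illustrative post-processing sketch (not code from this commit) of that interaction:

def trim_at_stopping_strings(reply, stopping_strings=('### End',)):
    # Cut the decoded reply at the first stopping string that survived decoding.
    for s in stopping_strings:
        idx = reply.find(s)
        if idx != -1:
            return reply[:idx]
    return reply

print(trim_at_stopping_strings('Paris is the capital of France.### End### Instruction:'))
# -> 'Paris is the capital of France.'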

View File

@@ -25,7 +25,7 @@ def list_model_elements():
def list_interface_input_elements(chat=False):
- elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
+ elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']
elements += list_model_elements()
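Adding 'skip_special_tokens' to this list is what carries the checkbox value into the state dict that generate_reply() receives. A hypothetical stand-in for the Gradio wiring (the real gathering is done by the webui, not by this snippet):

element_names = ['max_new_tokens', 'temperature', 'skip_special_tokens']   # abridged element list
element_values = [200, 0.7, True]                                          # values read from the UI widgets

state = dict(zip(element_names, element_values))
print(state['skip_special_tokens'])   # -> True; this flag is then passed on to decode() per request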

View File

@@ -424,7 +424,9 @@ def create_settings_menus(default_preset):
with gr.Group():
with gr.Row():
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
- shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='This forces the model to never end the generation prematurely.')
+ shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
+ shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
@@ -766,7 +768,7 @@ def create_interface():
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['instruction_template'].change(
- lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
+ chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload(
@@ -784,6 +786,7 @@ def create_interface():
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
+ shared.gradio['interface'].load(chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)

View File

@@ -12,6 +12,7 @@
"stop_at_newline": false,
"add_bos_token": true,
"ban_eos_token": false,
"skip_special_tokens": true,
"truncation_length": 2048,
"truncation_length_min": 0,
"truncation_length_max": 4096,