From 6430acadde543822cb7b17b5fba3bab1af682558 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 5 Dec 2023 10:05:54 -0800 Subject: [PATCH] Minor bug fix after https://github.com/oobabooga/text-generation-webui/pull/4814 --- modules/text_generation.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index ca379fd7..4cf4f720 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -264,14 +264,10 @@ def apply_stopping_strings(reply, all_stop_strings): def get_reply_from_output_ids(output_ids, state, starting_from=0): - if shared.is_seq2seq: - reply = decode(output_ids, state['skip_special_tokens']) - else: - reply = decode(output_ids[starting_from:], state['skip_special_tokens']) - # Prevent LlamaTokenizer from skipping a space - if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0: - if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'): - reply = ' ' + reply + reply = decode(output_ids[starting_from:], state['skip_special_tokens']) + if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > starting_from: + if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'): + reply = ' ' + reply return reply @@ -343,7 +339,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings if cuda: output = output.cuda() - yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0])) + starting_from = 0 if shared.is_seq2seq else len(input_ids[0]) + yield get_reply_from_output_ids(output, state, starting_from=starting_from) # Stream the reply 1 token at a time. # This is based on the trick of using 'stopping_criteria' to create an iterator. @@ -360,7 +357,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings with generate_with_streaming(**generate_params) as generator: cumulative_reply = '' - starting_from = len(input_ids[0]) + starting_from = 0 if shared.is_seq2seq else len(input_ids[0]) for output in generator: if output[-1] in eos_token_ids: break