Better handle spaces in LlamaTokenizer

This commit is contained in:
oobabooga 2023-05-11 17:55:50 -03:00
parent 7221d1389a
commit 71693161eb

View File

@ -107,8 +107,10 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
else:
new_tokens = len(output_ids) - len(input_ids[0])
reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
if type(shared.tokenizer) is transformers.LlamaTokenizer:
if len(original_question) > 0 and original_question[-1] not in [' ', '\n']:
# Prevent LlamaTokenizer from skipping a space
if type(shared.tokenizer) is transformers.LlamaTokenizer and len(output_ids) > 0:
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
reply = ' ' + reply
if not is_chat: