Truncate long chat completions inputs (#5439)

This commit is contained in:
oobabooga 2024-02-05 02:31:24 -03:00 committed by GitHub
parent 9033fa5eee
commit 7073665a10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 10 deletions

View File

@ -166,18 +166,53 @@ def generate_chat_prompt(user_input, state, **kwargs):
prompt = remove_extra_bos(prompt)
return prompt
prompt = make_prompt(messages)
# Handle truncation
max_length = get_max_prompt_length(state)
while len(messages) > 0 and get_encoded_length(prompt) > max_length:
# Try to save the system message
if len(messages) > 1 and messages[0]['role'] == 'system':
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
while len(messages) > 0 and encoded_length > max_length:
# Remove old message, save system message
if len(messages) > 2 and messages[0]['role'] == 'system':
messages.pop(1)
else:
# Remove old message when no system message is present
elif len(messages) > 1 and messages[0]['role'] != 'system':
messages.pop(0)
# Resort to truncating the user input
else:
user_message = messages[-1]['content']
# Bisect the truncation point
left, right = 0, len(user_message) - 1
while right - left > 1:
mid = (left + right) // 2
messages[-1]['content'] = user_message[mid:]
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
if encoded_length <= max_length:
right = mid
else:
left = mid
messages[-1]['content'] = user_message[right:]
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
if encoded_length > max_length:
logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
raise ValueError
else:
logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}.")
break
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
if also_return_rows:
return prompt, [message['content'] for message in messages]

View File

@ -50,6 +50,11 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
else:
generate_func = generate_reply_HF
if generate_func != generate_reply_HF and shared.args.verbose:
logger.info("PROMPT=")
print(question)
print()
# Prepare the input
original_question = question
if not is_chat:
@ -65,10 +70,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
if type(st) is list and len(st) > 0:
all_stop_strings += st
if shared.args.verbose:
logger.info("PROMPT=")
print(question)
shared.stop_everything = False
clear_torch_cache()
seed = set_manual_seed(state['seed'])
@ -355,6 +356,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params)
print()
logger.info("PROMPT=")
print(decode(input_ids[0], skip_special_tokens=False))
print()
t0 = time.time()
try:
if not is_chat and not shared.is_seq2seq: