mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Truncate long chat completions inputs (#5439)
This commit is contained in:
parent
9033fa5eee
commit
7073665a10
@ -166,18 +166,53 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
prompt = remove_extra_bos(prompt)
|
||||
return prompt
|
||||
|
||||
prompt = make_prompt(messages)
|
||||
|
||||
# Handle truncation
|
||||
max_length = get_max_prompt_length(state)
|
||||
while len(messages) > 0 and get_encoded_length(prompt) > max_length:
|
||||
# Try to save the system message
|
||||
if len(messages) > 1 and messages[0]['role'] == 'system':
|
||||
prompt = make_prompt(messages)
|
||||
encoded_length = get_encoded_length(prompt)
|
||||
|
||||
while len(messages) > 0 and encoded_length > max_length:
|
||||
|
||||
# Remove old message, save system message
|
||||
if len(messages) > 2 and messages[0]['role'] == 'system':
|
||||
messages.pop(1)
|
||||
else:
|
||||
|
||||
# Remove old message when no system message is present
|
||||
elif len(messages) > 1 and messages[0]['role'] != 'system':
|
||||
messages.pop(0)
|
||||
|
||||
# Resort to truncating the user input
|
||||
else:
|
||||
|
||||
user_message = messages[-1]['content']
|
||||
|
||||
# Bisect the truncation point
|
||||
left, right = 0, len(user_message) - 1
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
|
||||
messages[-1]['content'] = user_message[mid:]
|
||||
prompt = make_prompt(messages)
|
||||
encoded_length = get_encoded_length(prompt)
|
||||
|
||||
if encoded_length <= max_length:
|
||||
right = mid
|
||||
else:
|
||||
left = mid
|
||||
|
||||
messages[-1]['content'] = user_message[right:]
|
||||
prompt = make_prompt(messages)
|
||||
encoded_length = get_encoded_length(prompt)
|
||||
if encoded_length > max_length:
|
||||
logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
|
||||
raise ValueError
|
||||
else:
|
||||
logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}.")
|
||||
break
|
||||
|
||||
prompt = make_prompt(messages)
|
||||
encoded_length = get_encoded_length(prompt)
|
||||
|
||||
if also_return_rows:
|
||||
return prompt, [message['content'] for message in messages]
|
||||
|
@ -50,6 +50,11 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||
else:
|
||||
generate_func = generate_reply_HF
|
||||
|
||||
if generate_func != generate_reply_HF and shared.args.verbose:
|
||||
logger.info("PROMPT=")
|
||||
print(question)
|
||||
print()
|
||||
|
||||
# Prepare the input
|
||||
original_question = question
|
||||
if not is_chat:
|
||||
@ -65,10 +70,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||
if type(st) is list and len(st) > 0:
|
||||
all_stop_strings += st
|
||||
|
||||
if shared.args.verbose:
|
||||
logger.info("PROMPT=")
|
||||
print(question)
|
||||
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
@ -355,6 +356,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params)
|
||||
print()
|
||||
|
||||
logger.info("PROMPT=")
|
||||
print(decode(input_ids[0], skip_special_tokens=False))
|
||||
print()
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not is_chat and not shared.is_seq2seq:
|
||||
|
Loading…
Reference in New Issue
Block a user