Better way of finding the generated reply in the output string

This commit is contained in:
oobabooga 2023-01-19 14:57:01 -03:00
parent d03b0ad7a8
commit 849e4c7f90

View File

@ -136,14 +136,17 @@ def decode(output_ids):
return reply
def formatted_outputs(reply, model_name):
if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
if not (args.chat or args.cai_chat):
if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
else:
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
return reply
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
global model, tokenizer, model_name, loaded_preset, preset
@ -245,16 +248,17 @@ if args.chat or args.cai_chat:
question = generate_chat_prompt(text, tokens, name1, name2, context)
history.append(['', ''])
eos_token = '\n' if check else None
for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
reply = i[0]
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
next_character_found = False
previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", question)]
idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
idx = idx[len(previous_idx)-1]
reply = reply[idx + len(f"\n{name2}:"):]
if check:
idx = reply.rfind(question[-1024:])
reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
reply = reply.split('\n')[0].strip()
else:
idx = reply.rfind(question[-1024:])
reply = reply[idx+min(1024, len(question)):]
idx = reply.find(f"\n{name1}:")
if idx != -1:
reply = reply[:idx]