From 849e4c7f900cc47d6105e97eb87fe980319b37ce Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 19 Jan 2023 14:57:01 -0300
Subject: [PATCH] Better way of finding the generated reply in the output
 string

---
 server.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/server.py b/server.py
index d9a26265..b5121f63 100644
--- a/server.py
+++ b/server.py
@@ -136,14 +136,17 @@ def decode(output_ids):
     return reply
 
 def formatted_outputs(reply, model_name):
-    if model_name.lower().startswith('galactica'):
-        reply = fix_galactica(reply)
-        return reply, reply, generate_basic_html(reply)
-    elif model_name.lower().startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+    if not (args.chat or args.cai_chat):
+        if model_name.lower().startswith('galactica'):
+            reply = fix_galactica(reply)
+            return reply, reply, generate_basic_html(reply)
+        elif model_name.lower().startswith('gpt4chan'):
+            reply = fix_gpt4chan(reply)
+            return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+        else:
+            return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
     else:
-        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+        return reply
 
 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
@@ -245,16 +248,17 @@ if args.chat or args.cai_chat:
         question = generate_chat_prompt(text, tokens, name1, name2, context)
         history.append(['', ''])
         eos_token = '\n' if check else None
-        for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
-            reply = i[0]
+        for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
            next_character_found = False
+            previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", question)]
+            idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
+            idx = idx[len(previous_idx)-1]
+            reply = reply[idx + len(f"\n{name2}:"):]
+
             if check:
-                idx = reply.rfind(question[-1024:])
-                reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
+                reply = reply.split('\n')[0].strip()
             else:
-                idx = reply.rfind(question[-1024:])
-                reply = reply[idx+min(1024, len(question)):]
                 idx = reply.find(f"\n{name1}:")
                 if idx != -1:
                     reply = reply[:idx]
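
How the new extraction works: the chat prompt always ends with a fresh "\n{name2}:" marker (name1 and name2 being the user's and bot's display names), so the number of such markers in the prompt tells us which occurrence in the decoded output marks the start of the newly generated text; everything after that marker, optionally cut at the next "\n{name1}:" turn, is the bot's reply. Below is a minimal, self-contained sketch of that idea only; it is not part of the patch, and the helper name extract_new_reply and the example strings are hypothetical.

    import re

    def extract_new_reply(prompt, output, name1="You", name2="Assistant"):
        # Markers already present in the prompt (one per previous bot turn,
        # plus the trailing marker appended for the current turn).
        previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", prompt)]
        # Markers found in the full decoded output (prompt + generated text).
        idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", output)]
        # The last marker that came from the prompt is where the new reply begins.
        start = idx[len(previous_idx) - 1]
        reply = output[start + len(f"\n{name2}:"):]
        # Cut off anything the model generated for the other speaker's next turn.
        cut = reply.find(f"\n{name1}:")
        if cut != -1:
            reply = reply[:cut]
        return reply

    prompt = "Persona and context...\nYou: Hi there\nAssistant:"
    output = prompt + " Hello! How can I help you today?\nYou: I was wondering..."
    print(extract_new_reply(prompt, output))  # -> " Hello! How can I help you today?"

Compared with the old approach of rfind-ing the last 1024 characters of the prompt inside the output, this counts speaker markers instead, so it no longer depends on the prompt being reproduced verbatim in the decoded string.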