Fix character loading bug

This commit is contained in:
oobabooga 2023-01-26 13:45:19 -03:00
parent 3d9a499a02
commit ac6065d5ed

View File

@ -183,9 +183,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
cuda = "" if args.cpu else ".cuda()" cuda = "" if args.cpu else ".cuda()"
n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1] n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
input_ids = encode(question, tokens) input_ids = encode(question, tokens)
# The stopping_criteria code below was copied from
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
if stopping_string is not None: if stopping_string is not None:
# The stopping_criteria code below was copied from
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
t = encode(stopping_string, 0, add_special_tokens=False) t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list = transformers.StoppingCriteriaList([ stopping_criteria_list = transformers.StoppingCriteriaList([
_SentinelTokenStoppingCriteria( _SentinelTokenStoppingCriteria(
@ -382,16 +382,19 @@ if args.chat or args.cai_chat:
return generate_chat_html(_history, name1, name2, character) return generate_chat_html(_history, name1, name2, character)
def tokenize_dialogue(dialogue, name1, name2): def tokenize_dialogue(dialogue, name1, name2):
history = []
dialogue = re.sub('<START>', '', dialogue) dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)] idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
if len(idx) == 0:
return history
messages = [] messages = []
for i in range(len(idx)-1): for i in range(len(idx)-1):
messages.append(dialogue[idx[i]:idx[i+1]].strip()) messages.append(dialogue[idx[i]:idx[i+1]].strip())
messages.append(dialogue[idx[-1]:].strip()) messages.append(dialogue[idx[-1]:].strip())
history = []
entry = ['', ''] entry = ['', '']
for i in messages: for i in messages:
if i.startswith(f'{name1}:'): if i.startswith(f'{name1}:'):