mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Fix character loading bug
This commit is contained in:
parent
3d9a499a02
commit
ac6065d5ed
11
server.py
11
server.py
@ -183,9 +183,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||||||
cuda = "" if args.cpu else ".cuda()"
|
cuda = "" if args.cpu else ".cuda()"
|
||||||
n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
||||||
input_ids = encode(question, tokens)
|
input_ids = encode(question, tokens)
|
||||||
# The stopping_criteria code below was copied from
|
|
||||||
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
|
||||||
if stopping_string is not None:
|
if stopping_string is not None:
|
||||||
|
# The stopping_criteria code below was copied from
|
||||||
|
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||||
t = encode(stopping_string, 0, add_special_tokens=False)
|
t = encode(stopping_string, 0, add_special_tokens=False)
|
||||||
stopping_criteria_list = transformers.StoppingCriteriaList([
|
stopping_criteria_list = transformers.StoppingCriteriaList([
|
||||||
_SentinelTokenStoppingCriteria(
|
_SentinelTokenStoppingCriteria(
|
||||||
@ -382,16 +382,19 @@ if args.chat or args.cai_chat:
|
|||||||
return generate_chat_html(_history, name1, name2, character)
|
return generate_chat_html(_history, name1, name2, character)
|
||||||
|
|
||||||
def tokenize_dialogue(dialogue, name1, name2):
|
def tokenize_dialogue(dialogue, name1, name2):
|
||||||
|
history = []
|
||||||
|
|
||||||
dialogue = re.sub('<START>', '', dialogue)
|
dialogue = re.sub('<START>', '', dialogue)
|
||||||
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
||||||
|
|
||||||
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
|
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
|
||||||
|
if len(idx) == 0:
|
||||||
|
return history
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
for i in range(len(idx)-1):
|
for i in range(len(idx)-1):
|
||||||
messages.append(dialogue[idx[i]:idx[i+1]].strip())
|
messages.append(dialogue[idx[i]:idx[i+1]].strip())
|
||||||
messages.append(dialogue[idx[-1]:].strip())
|
messages.append(dialogue[idx[-1]:].strip())
|
||||||
|
|
||||||
history = []
|
|
||||||
entry = ['', '']
|
entry = ['', '']
|
||||||
for i in messages:
|
for i in messages:
|
||||||
if i.startswith(f'{name1}:'):
|
if i.startswith(f'{name1}:'):
|
||||||
|
Loading…
Reference in New Issue
Block a user