mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00

Add support for extensions

This is experimental.

This commit is contained in:
parent 414fa9d161
commit 6b5dcd46c5
@@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags:
 | `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
 | `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.|
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
+| `--extensions EXTENSIONS` | The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this". |
 | `--listen` | Make the web UI reachable from your local network.|
 | `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
 | `--verbose` | Print the prompts to the terminal. |
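
For instance, to start the web UI with the example extension added by this commit (a usage sketch; any comma-separated list of extension names works the same way):

    python server.py --extensions example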
14 extensions/example/script.py Normal file

@@ -0,0 +1,14 @@
+def input_modifier(string):
+    """
+    This function is applied to your text inputs before
+    they are fed into the model.
+    """
+
+    return string.replace(' ', '#')
+
+def output_modifier(string):
+    """
+    This function is applied to the model outputs.
+    """
+
+    return string.replace(' ', '_')
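
An extension is just a folder under extensions/ whose script.py defines `input_modifier` and `output_modifier` (see `get_available_extensions()` in server.py below). A minimal sketch of a custom extension, assuming a hypothetical folder `extensions/uppercase/`:

    # extensions/uppercase/script.py -- hypothetical custom extension
    def input_modifier(string):
        # Applied to text inputs before they are fed into the model.
        return string

    def output_modifier(string):
        # Applied to the model outputs; here, upper-case the reply.
        return string.upper()

It would then be enabled with `python server.py --extensions uppercase`.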
141 server.py

@@ -5,6 +5,7 @@ import glob
 import torch
 import argparse
 import json
+import sys
 from sys import exit
 from pathlib import Path
 import gradio as gr
@@ -32,6 +33,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
 parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
+parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
 parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
@@ -165,6 +167,9 @@ def formatted_outputs(reply, model_name):
 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
+    original_question = question
+    if not (args.chat or args.cai_chat):
+        question = apply_extensions(question, "input")
     if args.verbose:
         print(f"\n\n{question}\n--------------------\n")
 
@@ -203,20 +208,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         reply = decode(output[0])
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
+        if not (args.chat or args.cai_chat):
+            reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
 
     # Generate the reply 1 token at a time
     else:
-        yield formatted_outputs(question, model_name)
+        yield formatted_outputs(original_question, model_name)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
         for i in tqdm(range(tokens//8+1)):
             output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
             reply = decode(output[0])
+            if not (args.chat or args.cai_chat):
+                reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
             input_ids = output
             if output[0][-1] == n:
                 break
 
+def apply_extensions(text, typ):
+    global available_extensions, extension_state
+    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
+        if extension_state[ext][0] == True:
+            ext_string = f"extensions.{ext}.script"
+            exec(f"import {ext_string}")
+            if typ == "input":
+                text = eval(f"{ext_string}.input_modifier(text)")
+            else:
+                text = eval(f"{ext_string}.output_modifier(text)")
+    return text
+
 def get_available_models():
     return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
 
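
A standalone sketch of what `apply_extensions()` does at runtime, rewritten with `importlib` instead of `exec`/`eval` purely for illustration (the commit itself uses the `exec`/`eval` form above); `extension_state` maps each enabled extension name to `[enabled, load order]`:

    import importlib

    extension_state = {'example': [True, 0]}  # name -> [enabled, load order]

    def apply_extensions_sketch(text, typ):
        # Run each enabled extension's modifier, in command-line order.
        for ext in sorted(extension_state, key=lambda x: extension_state[x][1]):
            if extension_state[ext][0]:
                module = importlib.import_module(f'extensions.{ext}.script')
                text = module.input_modifier(text) if typ == 'input' else module.output_modifier(text)
        return text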
@@ -226,9 +247,19 @@ def get_available_presets():
 def get_available_characters():
     return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
 
+def get_available_extensions():
+    return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
 available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
+available_extensions = get_available_extensions()
+extension_state = {}
+if args.extensions is not None:
+    for i,ext in enumerate(args.extensions.split(',')):
+        if ext in available_extensions:
+            print(f'The extension "{ext}" is enabled.')
+            extension_state[ext] = [True, i]
 
 # Choosing the default model
 if args.model is not None:
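
To illustrate the state built above with an assumed invocation: `--extensions "example,other"` (where a hypothetical `extensions/other/script.py` also exists) would produce

    extension_state = {'example': [True, 0], 'other': [True, 1]}

so `apply_extensions()` runs `example` before `other`, since the second element records command-line order.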
@@ -256,7 +287,7 @@ description = f"\n\n# Text generation lab\nGenerate text using Large Language Mo
 css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
 
 if args.chat or args.cai_chat:
-    history = []
+    history = {'internal': [], 'visible': []}
     character = None
 
     # This gets the new line characters right.
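
The flat `history` list becomes two parallel logs: 'internal' holds what is actually exchanged with the model (inputs after `input_modifier`, raw replies), while 'visible' holds what the user sees (the original input, and the reply after `output_modifier`), as the `chatbot_wrapper` changes below show. An illustration with assumed values, using the bundled example extension:

    history = {
        'internal': [['hello#there', 'General Kenobi!']],   # '#' from input_modifier
        'visible':  [['hello there', 'General_Kenobi!']],   # '_' from output_modifier
    }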
@@ -270,13 +301,13 @@ if args.chat or args.cai_chat:
         text = clean_chat_message(text)
 
         rows = [f"{context.strip()}\n"]
-        i = len(history)-1
+        i = len(history['internal'])-1
         count = 0
         while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
-            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
             count += 1
-            if not (history[i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
-                rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+            if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
+                rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
                 count += 1
             i -= 1
             if history_size != 0 and count >= history_size:
@@ -291,18 +322,12 @@ if args.chat or args.cai_chat:
         question = ''.join(rows)
         return question
 
-    def remove_example_dialogue_from_history(history):
-        _history = copy.deepcopy(history)
-        for i in range(len(_history)):
-            if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
-                _history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
-                _history = _history[i:]
-                break
-        return _history
-
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+        original_text = text
+        text = apply_extensions(text, "input")
         question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
-        history.append(['', ''])
+        history['internal'].append(['', ''])
+        history['visible'].append(['', ''])
         eos_token = '\n' if check else None
         for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
             next_character_found = False
@@ -312,7 +337,6 @@ if args.chat or args.cai_chat:
                 idx = idx[len(previous_idx)-1]
-
             reply = reply[idx + len(f"\n{name2}:"):]
 
             if check:
                 reply = reply.split('\n')[0].strip()
             else:
@@ -322,7 +346,8 @@ if args.chat or args.cai_chat:
                     next_character_found = True
             reply = clean_chat_message(reply)
 
-            history[-1] = [text, reply]
+            history['internal'][-1] = [text, reply]
+            history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
             if next_character_found:
                 break
 
@@ -335,16 +360,17 @@ if args.chat or args.cai_chat:
                         next_character_substring_found = True
 
                 if not next_character_substring_found:
-                    yield remove_example_dialogue_from_history(history)
+                    yield history['visible']
 
-        yield remove_example_dialogue_from_history(history)
+        yield history['visible']
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-            yield generate_chat_html(history, name1, name2, character)
+        for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+            yield generate_chat_html(_history, name1, name2, character)
 
     def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-        last = history.pop()
+        last = history['internal'].pop()
+        history['visible'].pop()
         text = last[0]
         if args.cai_chat:
             for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
@@ -354,12 +380,15 @@ if args.chat or args.cai_chat:
                 yield i
 
     def remove_last_message(name1, name2):
-        last = history.pop()
-        _history = remove_example_dialogue_from_history(history)
-        if args.cai_chat:
-            return generate_chat_html(_history, name1, name2, character), last[0]
+        if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
+            last = history['visible'].pop()
+            history['internal'].pop()
         else:
-            return _history, last[0]
+            last = ['', '']
+        if args.cai_chat:
+            return generate_chat_html(history['visible'], name1, name2, character), last[0]
+        else:
+            return history['visible'], last[0]
 
     def clear_html():
         return generate_chat_html([], "", "", character)
@@ -367,28 +396,31 @@ if args.chat or args.cai_chat:
     def clear_chat_log(_character, name1, name2):
         global history
         if _character != 'None':
-            load_character(_character, name1, name2)
+            for i in range(len(history['internal'])):
+                if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
+                    history['visible'] = [['', history['internal'][i][1]]]
+                    history['internal'] = history['internal'][:i+1]
+                    break
         else:
-            history = []
-        _history = remove_example_dialogue_from_history(history)
+            history['internal'] = []
+            history['visible'] = []
         if args.cai_chat:
-            return generate_chat_html(_history, name1, name2, character)
+            return generate_chat_html(history['visible'], name1, name2, character)
         else:
-            return _history
+            return history['visible']
 
     def redraw_html(name1, name2):
         global history
-        _history = remove_example_dialogue_from_history(history)
-        return generate_chat_html(_history, name1, name2, character)
+        return generate_chat_html(history['visible'], name1, name2, character)
 
     def tokenize_dialogue(dialogue, name1, name2):
-        history = []
+        _history = []
 
         dialogue = re.sub('<START>', '', dialogue)
         dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
         idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
         if len(idx) == 0:
-            return history
+            return _history
 
         messages = []
         for i in range(len(idx)-1):
|
|||||||
elif i.startswith(f'{name2}:'):
|
elif i.startswith(f'{name2}:'):
|
||||||
entry[1] = i[len(f'{name2}:'):].strip()
|
entry[1] = i[len(f'{name2}:'):].strip()
|
||||||
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
|
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
|
||||||
history.append(entry)
|
_history.append(entry)
|
||||||
entry = ['', '']
|
entry = ['', '']
|
||||||
|
|
||||||
return history
|
return _history
|
||||||
|
|
||||||
def save_history():
|
def save_history():
|
||||||
if not Path('logs').exists():
|
if not Path('logs').exists():
|
||||||
Path('logs').mkdir()
|
Path('logs').mkdir()
|
||||||
with open(Path('logs/conversation.json'), 'w') as f:
|
with open(Path('logs/conversation.json'), 'w') as f:
|
||||||
f.write(json.dumps({'data': history}, indent=2))
|
f.write(json.dumps({'data': history['internal']}, indent=2))
|
||||||
return Path('logs/conversation.json')
|
return Path('logs/conversation.json')
|
||||||
|
|
||||||
def upload_history(file, name1, name2):
|
def upload_history(file, name1, name2):
|
||||||
@ -420,21 +452,22 @@ if args.chat or args.cai_chat:
|
|||||||
try:
|
try:
|
||||||
j = json.loads(file)
|
j = json.loads(file)
|
||||||
if 'data' in j:
|
if 'data' in j:
|
||||||
history = j['data']
|
history['internal'] = j['data']
|
||||||
# Compatibility with Pygmalion AI's official web UI
|
# Compatibility with Pygmalion AI's official web UI
|
||||||
elif 'chat' in j:
|
elif 'chat' in j:
|
||||||
history = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
||||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
||||||
history = [['<|BEGIN-VISIBLE-CHAT|>', history[0]]] + [[history[i], history[i+1]] for i in range(1, len(history)-1, 2)]
|
history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
|
||||||
else:
|
else:
|
||||||
history = [[history[i], history[i+1]] for i in range(0, len(history)-1, 2)]
|
history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
|
||||||
except:
|
except:
|
||||||
history = tokenize_dialogue(file, name1, name2)
|
history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||||
|
|
||||||
def load_character(_character, name1, name2):
|
def load_character(_character, name1, name2):
|
||||||
global history, character
|
global history, character
|
||||||
context = ""
|
context = ""
|
||||||
history = []
|
history['internal'] = []
|
||||||
|
history['visible'] = []
|
||||||
if _character != 'None':
|
if _character != 'None':
|
||||||
character = _character
|
character = _character
|
||||||
with open(Path(f'characters/{_character}.json'), 'r') as f:
|
with open(Path(f'characters/{_character}.json'), 'r') as f:
|
||||||
@ -446,24 +479,24 @@ if args.chat or args.cai_chat:
|
|||||||
context += f"Scenario: {data['world_scenario']}\n"
|
context += f"Scenario: {data['world_scenario']}\n"
|
||||||
context = f"{context.strip()}\n<START>\n"
|
context = f"{context.strip()}\n<START>\n"
|
||||||
if 'example_dialogue' in data and data['example_dialogue'] != '':
|
if 'example_dialogue' in data and data['example_dialogue'] != '':
|
||||||
history = tokenize_dialogue(data['example_dialogue'], name1, name2)
|
history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
|
||||||
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
|
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
|
||||||
history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
|
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
|
||||||
|
history['visible'] += [['', data['char_greeting']]]
|
||||||
else:
|
else:
|
||||||
history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
|
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
|
||||||
|
history['visible'] += [['', "Hello there!"]]
|
||||||
else:
|
else:
|
||||||
character = None
|
character = None
|
||||||
context = settings['context_pygmalion']
|
context = settings['context_pygmalion']
|
||||||
name2 = settings['name2_pygmalion']
|
name2 = settings['name2_pygmalion']
|
||||||
|
|
||||||
_history = remove_example_dialogue_from_history(history)
|
|
||||||
if args.cai_chat:
|
if args.cai_chat:
|
||||||
return name2, context, generate_chat_html(_history, name1, name2, character)
|
return name2, context, generate_chat_html(history['visible'], name1, name2, character)
|
||||||
else:
|
else:
|
||||||
return name2, context, _history
|
return name2, context, history['visible']
|
||||||
|
|
||||||
def upload_character(file, name1, name2):
|
def upload_character(file, name1, name2):
|
||||||
global history
|
|
||||||
file = file.decode('utf-8')
|
file = file.decode('utf-8')
|
||||||
data = json.loads(file)
|
data = json.loads(file)
|
||||||
outfile_name = data["char_name"]
|
outfile_name = data["char_name"]
|
||||||
@ -543,7 +576,7 @@ if args.chat or args.cai_chat:
|
|||||||
if args.cai_chat:
|
if args.cai_chat:
|
||||||
upload.upload(redraw_html, [name1, name2], [display1])
|
upload.upload(redraw_html, [name1, name2], [display1])
|
||||||
else:
|
else:
|
||||||
upload.upload(lambda : remove_example_dialogue_from_history(history), [], [display1])
|
upload.upload(lambda : history['visible'], [], [display1])
|
||||||
|
|
||||||
elif args.notebook:
|
elif args.notebook:
|
||||||
with gr.Blocks(css=css, analytics_enabled=False) as interface:
|
with gr.Blocks(css=css, analytics_enabled=False) as interface:
|
||||||
|