Add support for extensions

This is experimental.
This commit is contained in:
oobabooga 2023-01-27 00:40:39 -03:00
parent 414fa9d161
commit 6b5dcd46c5
3 changed files with 102 additions and 54 deletions

View File

@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags:
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
| `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.|
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
| `--extensions EXTENSIONS` | The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this". |
| `--listen` | Make the web UI reachable from your local network.|
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--verbose` | Print the prompts to the terminal. |

View File

@ -0,0 +1,14 @@
def input_modifier(string):
"""
This function is applied to your text inputs before
they are fed into the model.
"""
return string.replace(' ', '#')
def output_modifier(string):
"""
This function is applied to the model outputs.
"""
return string.replace(' ', '_')

141
server.py
View File

@ -5,6 +5,7 @@ import glob
import torch
import argparse
import json
import sys
from sys import exit
from pathlib import Path
import gradio as gr
@ -32,6 +33,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
@ -165,6 +167,9 @@ def formatted_outputs(reply, model_name):
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
global model, tokenizer, model_name, loaded_preset, preset
original_question = question
if not (args.chat or args.cai_chat):
question = apply_extensions(question, "input")
if args.verbose:
print(f"\n\n{question}\n--------------------\n")
@ -203,20 +208,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
reply = decode(output[0])
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
if not (args.chat or args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, model_name)
# Generate the reply 1 token at a time
else:
yield formatted_outputs(question, model_name)
yield formatted_outputs(original_question, model_name)
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
for i in tqdm(range(tokens//8+1)):
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
reply = decode(output[0])
if not (args.chat or args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, model_name)
input_ids = output
if output[0][-1] == n:
break
def apply_extensions(text, typ):
global available_extensions, extension_state
for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
if extension_state[ext][0] == True:
ext_string = f"extensions.{ext}.script"
exec(f"import {ext_string}")
if typ == "input":
text = eval(f"{ext_string}.input_modifier(text)")
else:
text = eval(f"{ext_string}.output_modifier(text)")
return text
def get_available_models():
return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
@ -226,9 +247,19 @@ def get_available_presets():
def get_available_characters():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
def get_available_extensions():
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
available_extensions = get_available_extensions()
extension_state = {}
if args.extensions is not None:
for i,ext in enumerate(args.extensions.split(',')):
if ext in available_extensions:
print(f'The extension "{ext}" is enabled.')
extension_state[ext] = [True, i]
# Choosing the default model
if args.model is not None:
@ -256,7 +287,7 @@ description = f"\n\n# Text generation lab\nGenerate text using Large Language Mo
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
if args.chat or args.cai_chat:
history = []
history = {'internal': [], 'visible': []}
character = None
# This gets the new line characters right.
@ -270,13 +301,13 @@ if args.chat or args.cai_chat:
text = clean_chat_message(text)
rows = [f"{context.strip()}\n"]
i = len(history)-1
i = len(history['internal'])-1
count = 0
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
count += 1
if not (history[i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
count += 1
i -= 1
if history_size != 0 and count >= history_size:
@ -291,18 +322,12 @@ if args.chat or args.cai_chat:
question = ''.join(rows)
return question
def remove_example_dialogue_from_history(history):
_history = copy.deepcopy(history)
for i in range(len(_history)):
if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
_history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
_history = _history[i:]
break
return _history
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
original_text = text
text = apply_extensions(text, "input")
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
history.append(['', ''])
history['internal'].append(['', ''])
history['visible'].append(['', ''])
eos_token = '\n' if check else None
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
next_character_found = False
@ -312,7 +337,6 @@ if args.chat or args.cai_chat:
idx = idx[len(previous_idx)-1]
reply = reply[idx + len(f"\n{name2}:"):]
if check:
reply = reply.split('\n')[0].strip()
else:
@ -322,7 +346,8 @@ if args.chat or args.cai_chat:
next_character_found = True
reply = clean_chat_message(reply)
history[-1] = [text, reply]
history['internal'][-1] = [text, reply]
history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
if next_character_found:
break
@ -335,16 +360,17 @@ if args.chat or args.cai_chat:
next_character_substring_found = True
if not next_character_substring_found:
yield remove_example_dialogue_from_history(history)
yield history['visible']
yield remove_example_dialogue_from_history(history)
yield history['visible']
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
yield generate_chat_html(history, name1, name2, character)
for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
yield generate_chat_html(_history, name1, name2, character)
def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
last = history.pop()
last = history['internal'].pop()
history['visible'].pop()
text = last[0]
if args.cai_chat:
for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
@ -354,12 +380,15 @@ if args.chat or args.cai_chat:
yield i
def remove_last_message(name1, name2):
last = history.pop()
_history = remove_example_dialogue_from_history(history)
if args.cai_chat:
return generate_chat_html(_history, name1, name2, character), last[0]
if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
last = history['visible'].pop()
history['internal'].pop()
else:
return _history, last[0]
last = ['', '']
if args.cai_chat:
return generate_chat_html(history['visible'], name1, name2, character), last[0]
else:
return history['visible'], last[0]
def clear_html():
return generate_chat_html([], "", "", character)
@ -367,28 +396,31 @@ if args.chat or args.cai_chat:
def clear_chat_log(_character, name1, name2):
global history
if _character != 'None':
load_character(_character, name1, name2)
for i in range(len(history['internal'])):
if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
history['visible'] = [['', history['internal'][i][1]]]
history['internal'] = history['internal'][:i+1]
break
else:
history = []
_history = remove_example_dialogue_from_history(history)
history['internal'] = []
history['visible'] = []
if args.cai_chat:
return generate_chat_html(_history, name1, name2, character)
return generate_chat_html(history['visible'], name1, name2, character)
else:
return _history
return history['visible']
def redraw_html(name1, name2):
global history
_history = remove_example_dialogue_from_history(history)
return generate_chat_html(_history, name1, name2, character)
return generate_chat_html(history['visible'], name1, name2, character)
def tokenize_dialogue(dialogue, name1, name2):
history = []
_history = []
dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
if len(idx) == 0:
return history
return _history
messages = []
for i in range(len(idx)-1):
@ -402,16 +434,16 @@ if args.chat or args.cai_chat:
elif i.startswith(f'{name2}:'):
entry[1] = i[len(f'{name2}:'):].strip()
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
history.append(entry)
_history.append(entry)
entry = ['', '']
return history
return _history
def save_history():
if not Path('logs').exists():
Path('logs').mkdir()
with open(Path('logs/conversation.json'), 'w') as f:
f.write(json.dumps({'data': history}, indent=2))
f.write(json.dumps({'data': history['internal']}, indent=2))
return Path('logs/conversation.json')
def upload_history(file, name1, name2):
@ -420,21 +452,22 @@ if args.chat or args.cai_chat:
try:
j = json.loads(file)
if 'data' in j:
history = j['data']
history['internal'] = j['data']
# Compatibility with Pygmalion AI's official web UI
elif 'chat' in j:
history = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
history = [['<|BEGIN-VISIBLE-CHAT|>', history[0]]] + [[history[i], history[i+1]] for i in range(1, len(history)-1, 2)]
history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
else:
history = [[history[i], history[i+1]] for i in range(0, len(history)-1, 2)]
history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
except:
history = tokenize_dialogue(file, name1, name2)
history['internal'] = tokenize_dialogue(file, name1, name2)
def load_character(_character, name1, name2):
global history, character
context = ""
history = []
history['internal'] = []
history['visible'] = []
if _character != 'None':
character = _character
with open(Path(f'characters/{_character}.json'), 'r') as f:
@ -446,24 +479,24 @@ if args.chat or args.cai_chat:
context += f"Scenario: {data['world_scenario']}\n"
context = f"{context.strip()}\n<START>\n"
if 'example_dialogue' in data and data['example_dialogue'] != '':
history = tokenize_dialogue(data['example_dialogue'], name1, name2)
history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
history['visible'] += [['', data['char_greeting']]]
else:
history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
history['visible'] += [['', "Hello there!"]]
else:
character = None
context = settings['context_pygmalion']
name2 = settings['name2_pygmalion']
_history = remove_example_dialogue_from_history(history)
if args.cai_chat:
return name2, context, generate_chat_html(_history, name1, name2, character)
return name2, context, generate_chat_html(history['visible'], name1, name2, character)
else:
return name2, context, _history
return name2, context, history['visible']
def upload_character(file, name1, name2):
global history
file = file.decode('utf-8')
data = json.loads(file)
outfile_name = data["char_name"]
@ -543,7 +576,7 @@ if args.chat or args.cai_chat:
if args.cai_chat:
upload.upload(redraw_html, [name1, name2], [display1])
else:
upload.upload(lambda : remove_example_dialogue_from_history(history), [], [display1])
upload.upload(lambda : history['visible'], [], [display1])
elif args.notebook:
with gr.Blocks(css=css, analytics_enabled=False) as interface: