Save uploaded characters as yaml

Also allow yaml characters to be uploaded directly
This commit is contained in:
oobabooga 2023-07-30 11:25:38 -07:00
parent c25602eb65
commit 6e16af34fd
3 changed files with 66 additions and 31 deletions

View File

@ -3,7 +3,6 @@ context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer
greeting: |- greeting: |-
*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started! Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
example_dialogue: |-
{{user}}: So how did you get into computer engineering? {{user}}: So how did you get into computer engineering?
{{char}}: I've always loved tinkering with technology since I was a kid. {{char}}: I've always loved tinkering with technology since I was a kid.
{{user}}: That's really impressive! {{user}}: That's really impressive!

View File

@ -27,6 +27,22 @@ from modules.utils import (
) )
def str_presenter(dumper, data):
    """
    Copied from https://github.com/yaml/pyyaml/issues/240
    Makes pyyaml output prettier multiline strings.

    Strings containing at least one newline are represented with the
    literal block style ('|'); all other strings are represented with the
    dumper's default scalar style.
    """
    str_tag = 'tag:yaml.org,2002:str'
    if '\n' in data:
        return dumper.represent_scalar(str_tag, data, style='|')

    return dumper.represent_scalar(str_tag, data)
# Register str_presenter as the representer for str, once through
# yaml.add_representer and once directly on SafeRepresenter
# (presumably so that safe dumps pick it up as well).
yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
def get_turn_substrings(state, instruct=False): def get_turn_substrings(state, instruct=False):
if instruct: if instruct:
if 'turn_template' not in state or state['turn_template'] == '': if 'turn_template' not in state or state['turn_template'] == '':
@ -438,18 +454,6 @@ def replace_character_names(text, name1, name2):
return text.replace('<USER>', name1).replace('<BOT>', name2) return text.replace('<USER>', name1).replace('<BOT>', name2)
def build_pygmalion_style_context(data):
    """
    Build the context string for a Pygmalion-style character dict.

    Reads 'char_persona' and 'world_scenario' from `data` (and 'char_name'
    when a persona is present). Missing, empty or None values are skipped
    instead of being rendered into the text. The result always ends with
    a '<START>' line.
    """
    parts = []

    persona = data.get('char_persona')
    if persona:
        parts.append(f"{data['char_name']}'s Persona: {persona}\n")

    scenario = data.get('world_scenario')
    if scenario:
        parts.append(f"Scenario: {scenario}\n")

    context = "".join(parts)
    return f"{context.strip()}\n<START>\n"
def generate_pfp_cache(character): def generate_pfp_cache(character):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@ -536,40 +540,72 @@ def load_character_memoized(character, name1, name2, instruct=False):
return load_character(character, name1, name2, instruct=instruct) return load_character(character, name1, name2, instruct=instruct)
def upload_character(file, img, tavern=False):
    """
    Save an uploaded character as a YAML file in the characters/ folder.

    `file` is the uploaded content, either a str or UTF-8 encoded bytes,
    holding JSON or YAML. Two layouts are understood:

    - Pygmalion style: 'char_name', 'char_greeting' and the keys read by
      build_pygmalion_style_context.
    - Native style: 'name', and optionally 'greeting' and 'context'.

    `img` is an optional image object; when it is not None it is saved
    next to the YAML file with a .png extension.

    `tavern` is not used here; it is kept so existing callers that pass
    it keep working.

    If a character file with the same name already exists, a numeric
    suffix (_001, _002, ...) is appended. Returns a gr.update that selects
    the saved character and refreshes the list of choices.
    """
    decoded_file = file if isinstance(file, str) else file.decode('utf-8')
    try:
        data = json.loads(decoded_file)
    except ValueError:
        # json.JSONDecodeError is a ValueError: the upload is not JSON,
        # so parse it as YAML instead.
        data = yaml.safe_load(decoded_file)

    if 'char_name' in data:
        name = data['char_name']
        greeting = data['char_greeting']
        context = build_pygmalion_style_context(data)
    else:
        name = data['name']
        # generate_character_yaml strips falsy fields, so a character that
        # was saved with an empty greeting or context has no such key.
        greeting = data.get('greeting', '')
        context = data.get('context', '')

    yaml_data = generate_character_yaml(name, greeting, context)

    # NOTE(review): `name` comes straight from the uploaded file and is
    # used as a file name without sanitization -- confirm this is acceptable.
    outfile_name = name
    i = 1
    while Path(f'characters/{outfile_name}.yaml').exists():
        outfile_name = f'{name}_{i:03d}'
        i += 1

    with open(Path(f'characters/{outfile_name}.yaml'), 'w', encoding='utf-8') as f:
        f.write(yaml_data)

    if img is not None:
        img.save(Path(f'characters/{outfile_name}.png'))

    logger.info(f'New character saved to "characters/{outfile_name}.yaml".')
    return gr.update(value=outfile_name, choices=get_available_characters())
def build_pygmalion_style_context(data):
    """
    Build the context string for a Pygmalion-style character dict.

    Reads 'char_persona' and 'world_scenario' from `data` (and 'char_name'
    when a persona is present). Missing, empty or None values are skipped
    instead of being rendered into the text. The result is the stripped
    text followed by a single newline.
    """
    parts = []

    persona = data.get('char_persona')
    if persona:
        parts.append(f"{data['char_name']}'s Persona: {persona}\n")

    scenario = data.get('world_scenario')
    if scenario:
        parts.append(f"Scenario: {scenario}\n")

    context = "".join(parts)
    return f"{context.strip()}\n"
def upload_tavern_character(img, _json):
    """
    Convert a TavernAI character card into YAML and save it.

    `_json` is the card's dict; the keys 'name', 'description',
    'first_mes', 'mes_example' and 'scenario' are read from it and mapped
    to the Pygmalion-style keys. `img` is forwarded unchanged to
    upload_character, whose result is returned.
    """
    character = {
        'char_name': _json['name'],
        'char_persona': _json['description'],
        'char_greeting': _json['first_mes'],
        'example_dialogue': _json['mes_example'],
        'world_scenario': _json['scenario'],
    }

    name = character['char_name']
    greeting = character['char_greeting']
    context = build_pygmalion_style_context(character)

    # Named yaml_data (not `yaml`) so the yaml module is not shadowed.
    yaml_data = generate_character_yaml(name, greeting, context)
    return upload_character(yaml_data, img, tavern=True)
def check_tavern_character(img):
    """
    Inspect an image for embedded TavernAI character data.

    Looks for a 'chara' entry in `img.info`. When it is absent, returns
    ("Not a TavernAI card", None, None, <update with interactive=False>).
    Otherwise the entry is base64-decoded and parsed as JSON; if the
    parsed object has a 'data' key, that nested object is used instead.
    Returns (name, description, parsed dict, <update with interactive=True>).
    """
    if "chara" not in img.info:
        return "Not a TavernAI card", None, None, gr.update(interactive=False)

    decoded_string = base64.b64decode(img.info['chara'])
    _json = json.loads(decoded_string)
    if "data" in _json:
        # Some cards nest the character fields under a 'data' key.
        _json = _json["data"]

    return _json['name'], _json['description'], _json, gr.update(interactive=True)
@ -595,7 +631,7 @@ def generate_character_yaml(name, greeting, context):
} }
data = {k: v for k, v in data.items() if v} # Strip falsy data = {k: v for k, v in data.items() if v} # Strip falsy
return yaml.dump(data, sort_keys=False) return yaml.dump(data, sort_keys=False, width=float("inf"))
def generate_instruction_template_yaml(user, bot, context, turn_template): def generate_instruction_template_yaml(user, bot, context, turn_template):
@ -607,7 +643,7 @@ def generate_instruction_template_yaml(user, bot, context, turn_template):
} }
data = {k: v for k, v in data.items() if v} # Strip falsy data = {k: v for k, v in data.items() if v} # Strip falsy
return yaml.dump(data, sort_keys=False) return yaml.dump(data, sort_keys=False, width=float("inf"))
def save_character(name, greeting, context, picture, filename): def save_character(name, greeting, context, picture, filename):

View File

@ -713,9 +713,9 @@ def create_interface():
shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'], label="Upload") shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'], label="Upload")
with gr.Tab('Upload character'): with gr.Tab('Upload character'):
with gr.Tab('JSON'): with gr.Tab('YAML or JSON'):
with gr.Row(): with gr.Row():
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'], label='JSON File') shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json', '.yaml'], label='JSON or YAML File')
shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)') shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)')
shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False) shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False)