Upload profile pictures from the web UI

This commit is contained in:
oobabooga 2023-01-28 19:16:37 -03:00
parent 69ffef4391
commit f71531186b

View File

@ -5,9 +5,11 @@ import glob
import torch
import argparse
import json
import io
import sys
from sys import exit
from pathlib import Path
from PIL import Image
import copy
import gradio as gr
import warnings
@ -504,19 +506,27 @@ if args.chat or args.cai_chat:
else:
return name2, context, history['visible']
def upload_character(file, name1, name2):
file = file.decode('utf-8')
data = json.loads(file)
def upload_character(json_file, img, name1, name2):
json_file = json_file.decode('utf-8')
data = json.loads(json_file)
outfile_name = data["char_name"]
i = 1
while Path(f'characters/{outfile_name}.json').exists():
outfile_name = f'{data["char_name"]}_{i:03d}'
i += 1
with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
f.write(file)
f.write(json_file)
if img is not None:
img = Image.open(io.BytesIO(img)).convert('RGB')
img.save(Path(f'characters/{outfile_name}.jpg'))
print(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name
def upload_your_profile_picture(img):
img = Image.open(io.BytesIO(img)).convert('RGB')
img.save(Path(f'img_me.jpg'))
print(f'Profile picture saved to "img_me.jpg"')
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
if args.cai_chat:
@ -559,7 +569,16 @@ if args.chat or args.cai_chat:
download = gr.File()
save_btn = gr.Button(value="Click me")
with gr.Tab('Upload character'):
with gr.Row():
with gr.Column():
gr.Markdown('1. Select the JSON file')
upload_char = gr.File(type='binary')
with gr.Column():
gr.Markdown('2. Select your character\'s profile picture (optional)')
upload_img = gr.File(type='binary')
upload_btn = gr.Button(value="Submit")
with gr.Tab('Upload your profile picture'):
upload_img_me = gr.File(type='binary')
input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
if args.cai_chat:
@ -579,12 +598,15 @@ if args.chat or args.cai_chat:
save_btn.click(save_history, inputs=[], outputs=[download])
character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])
upload.upload(upload_history, [upload, name1, name2], [])
upload_char.upload(upload_character, [upload_char, name1, name2], [character_menu])
upload_btn.click(upload_character, [upload_char, upload_img, name1, name2], [character_menu])
upload_img_me.upload(upload_your_profile_picture, [upload_img_me], [])
if args.cai_chat:
upload.upload(redraw_html, [name1, name2], [display1])
upload_img_me.upload(redraw_html, [name1, name2], [display1])
else:
upload.upload(lambda : history['visible'], [], [display1])
upload_img_me.upload(lambda : history['visible'], [], [display1])
elif args.notebook:
with gr.Blocks(css=css, analytics_enabled=False) as interface: