2023-03-15 11:33:26 -04:00
|
|
|
from pathlib import Path
|
|
|
|
|
2023-01-21 22:02:46 -05:00
|
|
|
import gradio as gr
|
2023-04-14 10:07:28 -04:00
|
|
|
import torch
|
2023-01-21 22:02:46 -05:00
|
|
|
|
2023-04-12 09:27:06 -04:00
|
|
|
from modules import shared
|
|
|
|
|
2023-03-15 11:33:26 -04:00
|
|
|
def _read_asset(relative_path):
    """Return the text of a UI asset located relative to this module's directory.

    The file is decoded explicitly as UTF-8 so that loading does not depend on
    the platform's default locale encoding (which differs e.g. on Windows).
    """
    with open(Path(__file__).resolve().parent / relative_path, 'r', encoding='utf-8') as f:
        return f.read()


# Stylesheets and scripts injected into the gradio interface.
css = _read_asset('../css/main.css')
chat_css = _read_asset('../css/chat.css')
main_js = _read_asset('../css/main.js')
chat_js = _read_asset('../css/chat.js')
|
2023-03-15 11:01:32 -04:00
|
|
|
|
2023-04-18 22:36:23 -04:00
|
|
|
# Glyph used as the label of refresh buttons (U+1F504, 🔄).
refresh_symbol = '\U0001f504'

# Default gradio theme with custom sans-serif / monospace font stacks and a few
# overridden style variables (border colour, large-button padding, subdued text colour).
theme = gr.themes.Default(
    font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
    font=['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    body_text_color_subdued='#484848',
    border_color_primary='#c5c5d2',
    button_large_padding='6px 12px',
)
|
2023-04-06 23:15:45 -04:00
|
|
|
|
2023-04-14 10:07:28 -04:00
|
|
|
def list_model_elements():
    """Return the names of the model-loading interface elements.

    The fixed names come first, followed by one ``gpu_memory_<index>`` entry for
    every device reported by ``torch.cuda.device_count()``. A new list is
    returned on each call.
    """
    fixed_names = [
        'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit',
        'wbits', 'groupsize', 'model_type', 'pre_layer',
    ]
    per_gpu_names = [f'gpu_memory_{gpu_index}' for gpu_index in range(torch.cuda.device_count())]
    return fixed_names + per_gpu_names
|
|
|
|
|
|
|
|
|
2023-04-12 09:27:06 -04:00
|
|
|
def list_interface_input_elements(chat=False):
    """Return the ordered list of interface element names.

    The generation-parameter names always come first. When ``chat`` is true,
    the chat-specific names follow. The names from ``list_model_elements()``
    are appended last in both modes.
    """
    generation_names = [
        'max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p',
        'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size',
        'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty',
        'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length',
        'custom_stopping_strings', 'skip_special_tokens',
    ]
    chat_names = [
        'name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size',
        'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template',
    ] if chat else []
    return generation_names + chat_names + list_model_elements()
|
|
|
|
|
|
|
|
|
|
|
|
def gather_interface_values(*args):
    """Build a dict pairing each name in ``shared.input_elements`` with the positional value at the same index.

    Extra positional values beyond the number of names are ignored. Supplying
    fewer values than names raises IndexError.
    """
    return {name: args[position] for position, name in enumerate(shared.input_elements)}
|
|
|
|
|
|
|
|
|
2023-04-14 10:07:28 -04:00
|
|
|
def apply_interface_values(state):
    """Return the values of ``state`` ordered like ``list_interface_input_elements``.

    The chat flag passed to ``list_interface_input_elements`` comes from
    ``shared.is_chat()``. A name missing from ``state`` propagates whatever
    ``state[name]`` raises (KeyError for a plain dict).
    """
    ordered_names = list_interface_input_elements(chat=shared.is_chat())
    values = []
    for name in ordered_names:
        values.append(state[name])
    return values
|
|
|
|
|
|
|
|
|
2023-01-21 22:02:46 -05:00
|
|
|
class ToolButton(gr.Button, gr.components.FormComponent):
    """Compact button intended to carry a single emoji label; fits inside gradio forms."""

    def __init__(self, **options):
        # Always use the "tool" variant; every other keyword is forwarded unchanged.
        super(ToolButton, self).__init__(variant="tool", **options)

    def get_block_name(self):
        # Report the block name as "button" (presumably so gradio renders it with
        # its standard button component despite the subclass).
        return "button"
|
|
|
|
|
2023-04-06 23:15:45 -04:00
|
|
|
|
2023-01-21 22:02:46 -05:00
|
|
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Create a ToolButton that refreshes ``refresh_component`` when clicked.

    Args:
        refresh_component: gradio component that is updated on each click.
        refresh_method: zero-argument callable invoked first on every click.
        refreshed_args: dict mapping component attribute names to new values,
            or a zero-argument callable returning such a dict. ``None`` (or an
            empty result) means "no attribute changes".
        elem_id: element id assigned to the created button.

    Returns:
        The ToolButton, already wired so that clicking it runs the refresh.
    """
    def refresh():
        refresh_method()
        # A callable lets the new values be recomputed at click time.
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Treat None/empty as "no changes" for BOTH uses below. Previously only
        # the gr.update call tolerated None, and the loop crashed on it first.
        args = args or {}
        # Store the new values on the Python-side component object...
        for k, v in args.items():
            setattr(refresh_component, k, v)
        # ...and return them as an update for the component listed in `outputs`.
        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
|