text-generation-webui/modules/ui.py

211 lines
5.2 KiB
Python
Raw Normal View History

2023-03-15 11:33:26 -04:00
from pathlib import Path
import gradio as gr
import torch
2023-04-12 09:27:06 -04:00
from modules import shared
2023-03-15 11:33:26 -04:00
# Static assets are read once, at import time, relative to this module's
# directory. The encoding is given explicitly so the result does not depend on
# the platform's locale default (e.g. cp1252 on Windows), which can fail or
# garble non-ASCII characters in the stylesheets and scripts.
with open(Path(__file__).resolve().parent / '../css/main.css', 'r', encoding='utf-8') as f:
    css = f.read()
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r', encoding='utf-8') as f:
    chat_css = f.read()
with open(Path(__file__).resolve().parent / '../js/main.js', 'r', encoding='utf-8') as f:
    main_js = f.read()
with open(Path(__file__).resolve().parent / '../js/save_files.js', 'r', encoding='utf-8') as f:
    save_files_js = f.read()

# Emoji used as the labels of the refresh / delete / save tool buttons.
refresh_symbol = '🔄'
delete_symbol = '🗑️'
save_symbol = '💾'
2023-04-18 22:36:23 -04:00
# Gradio theme for the UI: the Default theme with custom font stacks, then a
# few colour and padding overrides applied through .set().
theme = gr.themes.Default(
    font=['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
    font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
).set(
    border_color_primary='#c5c5d2',
    button_large_padding='6px 12px',
    body_text_color_subdued='#484848',
    background_fill_secondary='#eaeaea',
)
2023-08-06 20:49:27 -04:00
# JavaScript snippet that plays the notification sound. It is left empty when
# no "notification.mp3" exists in the current working directory.
audio_notification_js = (
    "document.querySelector('#audio_notification audio')?.play();"
    if Path("notification.mp3").exists()
    else ""
)
2023-05-03 20:43:17 -04:00
def list_model_elements():
    """
    Return the names of the model-loading interface elements.

    The result is a new list on every call: a fixed set of loader option
    names followed by one 'gpu_memory_<i>' entry for each CUDA device that
    torch reports.
    """
    names = [
        'loader',
        'cpu_memory',
        'auto_devices',
        'disk',
        'cpu',
        'bf16',
        'load_in_8bit',
        'trust_remote_code',
        'load_in_4bit',
        'compute_dtype',
        'quant_type',
        'use_double_quant',
        'wbits',
        'groupsize',
        'model_type',
        'pre_layer',
        'triton',
        'desc_act',
        'no_inject_fused_attention',
        'no_inject_fused_mlp',
        'no_use_cuda_fp16',
        'threads',
        'n_batch',
        'no_mmap',
        'low_vram',
        'mlock',
        'n_gpu_layers',
        'n_ctx',
        'n_gqa',
        'rms_norm_eps',
        'llama_cpp_seed',
        'gpu_split',
        'max_seq_len',
        'compress_pos_emb',
        'alpha_value',
    ]

    # One memory entry per visible CUDA device (none on CPU-only machines).
    device_total = torch.cuda.device_count()
    names.extend(f'gpu_memory_{index}' for index in range(device_total))
    return names
def list_interface_input_elements():
    """
    Return the ordered names of every interface input element.

    The list is built from three parts: the generation parameters, the
    elements specific to the active mode (selected through shared.args.chat
    and shared.args.notebook), and the model elements returned by
    list_model_elements().
    """
    names = [
        'max_new_tokens',
        'auto_max_new_tokens',
        'seed',
        'temperature',
        'top_p',
        'top_k',
        'typical_p',
        'epsilon_cutoff',
        'eta_cutoff',
        'repetition_penalty',
        'repetition_penalty_range',
        'encoder_repetition_penalty',
        'no_repeat_ngram_size',
        'min_length',
        'do_sample',
        'penalty_alpha',
        'num_beams',
        'length_penalty',
        'early_stopping',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'negative_prompt',
        'guidance_scale',
        'add_bos_token',
        'ban_eos_token',
        'truncation_length',
        'custom_stopping_strings',
        'skip_special_tokens',
        'stream',
        'tfs',
        'top_a',
    ]

    if shared.args.chat:
        # Chat mode: character, history and instruction-template elements.
        names.extend([
            'character_menu',
            'history',
            'name1',
            'name2',
            'greeting',
            'context',
            'chat_generation_attempts',
            'stop_at_newline',
            'mode',
            'instruction_template',
            'name1_instruct',
            'name2_instruct',
            'context_instruct',
            'turn_template',
            'chat_style',
            'chat-instruct_command',
        ])
    elif shared.args.notebook:
        # Notebook mode has a single textbox.
        names.append('textbox')
    else:
        # Neither chat nor notebook: an input textbox plus an output textbox.
        names.extend(['textbox', 'output_textbox'])

    names.extend(list_model_elements())
    return names
def gather_interface_values(*args):
    """
    Map positional values onto the interface element names.

    The i-th positional argument is stored under the i-th name returned by
    list_interface_input_elements(); an IndexError is raised when fewer
    values than names are supplied. Unless shared.args.multi_user is set,
    the resulting dict is also stored in shared.persistent_interface_state.

    Returns the dict of element name -> value.
    """
    names = list_interface_input_elements()
    state = {name: args[position] for position, name in enumerate(names)}

    if not shared.args.multi_user:
        shared.persistent_interface_state = state

    return state
def apply_interface_values(state, use_persistent=False):
    """
    Turn a state mapping into one output value per interface element.

    When use_persistent is true, the passed state is ignored and
    shared.persistent_interface_state is used instead. The returned list
    follows the order of list_interface_input_elements(): each entry is the
    stored value for that element, or a bare gr.update() when the element is
    missing from the state.
    """
    if use_persistent:
        state = shared.persistent_interface_state

    names = list_interface_input_elements()

    # Dummy, do nothing: an empty state produces only bare updates.
    if len(state) == 0:
        return [gr.update() for _ in names]

    values = []
    for name in names:
        if name in state:
            values.append(state[name])
        else:
            values.append(gr.update())
    return values
class ToolButton(gr.Button, gr.components.IOComponent):
    """
    Small button with single emoji as text, fits inside gradio forms

    Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui
    """

    def __init__(self, **kwargs):
        """Accept keyword arguments only and pass them on to the parent constructor."""
        super().__init__(**kwargs)

    def get_block_name(self):
        """Return the block name "button"."""
        return "button"
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class):
    """
    Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui

    Create a ToolButton labelled with refresh_symbol that refreshes another
    component when clicked.

    Parameters:
        refresh_component: the component that is updated on click; it is the
            sole output of the click handler.
        refresh_method: zero-argument callable invoked first on every click.
        refreshed_args: a dict of attribute name -> value, or a callable
            returning such a dict. None (or a callable returning None) is
            treated as "nothing to update".
        elem_class: passed to the button as elem_classes.

    Returns the created button.
    """
    def refresh():
        refresh_method()

        # Normalise a missing result to {} *before* iterating. Previously the
        # "or {}" fallback was applied only in the return statement, after
        # .items() had already raised on None.
        args = (refreshed_args() if callable(refreshed_args) else refreshed_args) or {}

        # Mirror the new values onto the Python-side component object as well
        # as sending them to the frontend through gr.update().
        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_classes=elem_class)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )

    return refresh_button
def create_delete_button(**kwargs):
    """Return a ToolButton labelled with delete_symbol; kwargs are forwarded to ToolButton."""
    button = ToolButton(value=delete_symbol, **kwargs)
    return button
def create_save_button(**kwargs):
    """Return a ToolButton labelled with save_symbol; kwargs are forwarded to ToolButton."""
    button = ToolButton(value=save_symbol, **kwargs)
    return button