2023-08-14 10:46:07 -04:00
|
|
|
import copy
|
2023-03-15 11:33:26 -04:00
|
|
|
from pathlib import Path
|
|
|
|
|
2023-01-21 22:02:46 -05:00
|
|
|
import gradio as gr
|
2023-04-14 10:07:28 -04:00
|
|
|
import torch
|
2023-08-14 10:46:07 -04:00
|
|
|
import yaml
|
2023-10-26 22:39:51 -04:00
|
|
|
from transformers import is_torch_xpu_available
|
2023-01-21 22:02:46 -05:00
|
|
|
|
2024-01-04 23:33:32 -05:00
|
|
|
import extensions
|
2023-04-12 09:27:06 -04:00
|
|
|
from modules import shared
|
|
|
|
|
2023-09-13 16:29:00 -04:00
|
|
|
# Static CSS/JS assets shipped with the UI. Paths are resolved relative to
# this module's own directory, not the current working directory. Files are
# decoded explicitly as UTF-8 so loading does not depend on the platform's
# default locale encoding (e.g. cp1252 on Windows), which can fail on
# non-ASCII characters.
_assets_dir = Path(__file__).resolve().parent


def _read_asset(relative_path):
    """Return the text content of an asset located relative to this module."""
    with open(_assets_dir / relative_path, 'r', encoding='utf-8') as f:
        return f.read()


css = _read_asset('../css/NotoSans/stylesheet.css')
css += _read_asset('../css/main.css')
js = _read_asset('../js/main.js')
save_files_js = _read_asset('../js/save_files.js')
switch_tabs_js = _read_asset('../js/switch_tabs.js')
show_controls_js = _read_asset('../js/show_controls.js')
update_big_picture_js = _read_asset('../js/update_big_picture.js')
|
2023-03-15 11:01:32 -04:00
|
|
|
|
2023-07-03 23:03:30 -04:00
|
|
|
# Emoji labels for small icon buttons; refresh_symbol is the label of the
# button built by create_refresh_button().
refresh_symbol, delete_symbol, save_symbol = '🔄', '🗑️', '💾'
|
|
|
|
|
2023-04-18 22:36:23 -04:00
|
|
|
# Base Gradio theme: 'Noto Sans' first in the body font stack and
# 'IBM Plex Mono' first in the monospace stack, with a few color and
# padding overrides applied through .set().
theme = (
    gr.themes.Default(
        font=['Noto Sans', 'Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
        font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
    )
    .set(
        border_color_primary='#c5c5d2',
        button_large_padding='6px 12px',
        body_text_color_subdued='#484848',
        background_fill_secondary='#eaeaea',
    )
)
|
2023-04-06 23:15:45 -04:00
|
|
|
|
2023-08-06 20:49:27 -04:00
|
|
|
# JavaScript that plays the notification sound element; it is an empty string
# when no "notification.mp3" exists in the current working directory.
audio_notification_js = (
    "document.querySelector('#audio_notification audio')?.play();"
    if Path("notification.mp3").exists()
    else ""
)
|
|
|
|
|
2023-05-03 20:43:17 -04:00
|
|
|
|
2023-04-14 10:07:28 -04:00
|
|
|
def list_model_elements():
    """Return the names of the model-loading UI elements.

    The fixed names come first, followed by one ``gpu_memory_<i>`` entry per
    device: XPU devices when ``is_torch_xpu_available()`` is true, CUDA
    devices otherwise.
    """
    names = [
        'loader',
        'filter_by_loader',
        'cpu_memory',
        'auto_devices',
        'disk',
        'cpu',
        'bf16',
        'load_in_8bit',
        'trust_remote_code',
        'no_use_fast',
        'use_flash_attention_2',
        'load_in_4bit',
        'compute_dtype',
        'quant_type',
        'use_double_quant',
        'wbits',
        'groupsize',
        'model_type',
        'pre_layer',
        'triton',
        'desc_act',
        'no_inject_fused_attention',
        'no_inject_fused_mlp',
        'no_use_cuda_fp16',
        'disable_exllama',
        'disable_exllamav2',
        'cfg_cache',
        'no_flash_attn',
        'num_experts_per_token',
        'cache_8bit',
        'threads',
        'threads_batch',
        'n_batch',
        'no_mmap',
        'mlock',
        'no_mul_mat_q',
        'n_gpu_layers',
        'tensor_split',
        'n_ctx',
        'gpu_split',
        'max_seq_len',
        'compress_pos_emb',
        'alpha_value',
        'rope_freq_base',
        'numa',
        'logits_all',
        'no_offload_kqv',
        'tensorcores',
        'hqq_backend',
    ]

    # One per-device memory slider name, numbered from 0.
    if is_torch_xpu_available():
        gpu_count = torch.xpu.device_count()
    else:
        gpu_count = torch.cuda.device_count()

    names.extend(f'gpu_memory_{idx}' for idx in range(gpu_count))
    return names
|
|
|
|
|
|
|
|
|
2023-07-03 23:03:30 -04:00
|
|
|
def list_interface_input_elements():
    """Return the ordered names of every UI element that makes up the interface state.

    Order matters: gather_interface_values() pairs these names positionally
    with the values it receives. The list is generation parameters, then chat
    elements, then notebook/default-tab elements, then list_model_elements().
    """
    generation_elements = [
        'max_new_tokens',
        'auto_max_new_tokens',
        'max_tokens_second',
        'max_updates_second',
        'seed',
        'temperature',
        'temperature_last',
        'dynamic_temperature',
        'dynatemp_low',
        'dynatemp_high',
        'dynatemp_exponent',
        'top_p',
        'min_p',
        'top_k',
        'typical_p',
        'epsilon_cutoff',
        'eta_cutoff',
        'repetition_penalty',
        'presence_penalty',
        'frequency_penalty',
        'repetition_penalty_range',
        'encoder_repetition_penalty',
        'no_repeat_ngram_size',
        'min_length',
        'do_sample',
        'penalty_alpha',
        'num_beams',
        'length_penalty',
        'early_stopping',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'grammar_string',
        'negative_prompt',
        'guidance_scale',
        'add_bos_token',
        'ban_eos_token',
        'custom_token_bans',
        'truncation_length',
        'custom_stopping_strings',
        'skip_special_tokens',
        'stream',
        'tfs',
        'top_a',
    ]

    chat_elements = [
        'textbox',
        'start_with',
        'character_menu',
        'history',
        'name1',
        'name2',
        'greeting',
        'context',
        'mode',
        'custom_system_message',
        'instruction_template_str',
        'chat_template_str',
        'chat_style',
        'chat-instruct_command',
    ]

    notebook_default_elements = [
        'textbox-notebook',
        'textbox-default',
        'output_textbox',
        'prompt_menu-default',
        'prompt_menu-notebook',
    ]

    return generation_elements + chat_elements + notebook_default_elements + list_model_elements()
|
|
|
|
|
|
|
|
|
|
|
|
def gather_interface_values(*args):
    """Build the interface-state dict from positional UI values.

    ``args`` are matched by position with list_interface_input_elements().
    Extra trailing values are ignored, and too few values raise IndexError.
    Unless ``shared.args.multi_user`` is set, the resulting dict is also
    stored as ``shared.persistent_interface_state``.
    """
    element_names = list_interface_input_elements()
    state = {name: args[position] for position, name in enumerate(element_names)}

    # In multi-user mode a single global state would leak between users.
    if not shared.args.multi_user:
        shared.persistent_interface_state = state

    return state
|
|
|
|
|
|
|
|
|
2023-04-24 02:05:47 -04:00
|
|
|
def apply_interface_values(state, use_persistent=False):
    """Convert a saved interface state into one Gradio output per input element.

    When ``use_persistent`` is true, ``state`` is ignored and
    ``shared.persistent_interface_state`` is used instead. For each name in
    list_interface_input_elements(), the stored value is returned if present.
    Otherwise a no-op ``gr.update()`` is returned, so an empty state leaves
    every component untouched.
    """
    source = shared.persistent_interface_state if use_persistent else state
    names = list_interface_input_elements()

    if len(source) == 0:
        # Nothing saved yet: emit a dummy update for every element.
        return [gr.update() for _ in names]

    outputs = []
    for name in names:
        outputs.append(source[name] if name in source else gr.update())

    return outputs
|
2023-04-14 10:07:28 -04:00
|
|
|
|
|
|
|
|
2024-01-09 07:20:10 -05:00
|
|
|
def save_settings(state, preset, extensions_list, show_controls, theme_state):
    """Serialize the current UI settings to YAML text.

    Starts from a deep copy of ``shared.settings``. Every key of ``state`` that
    is also a known setting overrides it, except the keys in ``exclude``,
    which keep their ``shared.settings`` values. Menu selections, enabled
    extensions, control visibility, theme, and each enabled extension's
    ``params`` are then added.

    Args:
        state: interface state dict (presumably the one built by
            gather_interface_values); must contain 'prompt_menu-default',
            'prompt_menu-notebook' and 'character_menu'.
        preset: name of the selected preset.
        extensions_list: names of the enabled extensions.
        show_controls: value stored under 'show_controls'.
        theme_state: 'dark' selects the dark theme; anything else is light.

    Returns:
        The settings as a YAML string, keys kept in insertion order and
        never line-wrapped.
    """
    output = copy.deepcopy(shared.settings)

    # These keys are not taken from the UI state; they keep the values
    # already present in shared.settings.
    exclude = {'name2', 'greeting', 'context', 'turn_template'}
    for k in state:
        if k in shared.settings and k not in exclude:
            output[k] = state[k]

    output['preset'] = preset
    output['prompt-default'] = state['prompt_menu-default']
    output['prompt-notebook'] = state['prompt_menu-notebook']
    output['character'] = state['character_menu']
    output['default_extensions'] = extensions_list
    output['seed'] = int(output['seed'])
    output['show_controls'] = show_controls
    output['dark_theme'] = theme_state == 'dark'

    # Save extension values in the UI
    for extension_name in extensions_list:
        module = getattr(extensions, extension_name, None)
        if module is None:
            # The extension is listed but was never imported into the
            # `extensions` package (e.g. it failed to load), so there are no
            # params to save. Previously this raised AttributeError and
            # aborted saving entirely.
            continue

        extension = module.script
        if hasattr(extension, 'params'):
            params = extension.params
            for param in params:
                # Extension params are stored flat as "<extension>-<param>".
                output[f"{extension_name}-{param}"] = params[param]

    return yaml.dump(output, sort_keys=False, width=float("inf"))
|
|
|
|
|
|
|
|
|
2023-09-26 08:44:04 -04:00
|
|
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class, interactive=True):
    """
    Create a button that refreshes ``refresh_component`` when clicked.

    Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui

    On click, ``refresh_method()`` is called first. Then ``refreshed_args``
    (a dict, a zero-argument callable returning a dict, or None for "no
    changes") is applied both as Python attributes on ``refresh_component``
    and, through ``gr.update``, to the rendered component.

    Args:
        refresh_component: Gradio component to update.
        refresh_method: zero-argument callable run before computing the update.
        refreshed_args: dict / callable returning a dict / None of component
            properties to set.
        elem_class: CSS class(es) passed to the button's ``elem_classes``.
        interactive: whether the button is clickable.

    Returns:
        The created ``gr.Button``.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Treat None as "no changes" consistently. Previously only the
        # gr.update() call tolerated None, and the loop below raised
        # AttributeError on it.
        args = args or {}

        # Keep the server-side component object in sync with what is sent
        # to the frontend.
        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = gr.Button(refresh_symbol, elem_classes=elem_class, interactive=interactive)
    refresh_button.click(
        # Returns a shallow copy of the update dict. The previous version also
        # tried to turn list values into tuples, but it tested the *key*
        # (`type(k) is list`). That can never be true because dict keys must
        # be hashable, so the copy below keeps the behavior it actually had.
        fn=lambda: dict(refresh()),
        inputs=[],
        outputs=[refresh_component]
    )

    return refresh_button
|