import gc
import io
import json
import os
import re
import sys
import time
import zipfile
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import transformers
from PIL import Image
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import modules.chat as chat
import modules.extensions as extensions_module
import modules.shared as shared
from modules.extensions import extension_state
from modules.extensions import load_extensions
from modules.extensions import update_extensions_parameters
from modules.html_generator import *
from modules.prompt import generate_reply
from modules.ui import *

transformers.logging.set_verbosity_error()

if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
    print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")

settings = {
    'max_new_tokens': 200,
    'max_new_tokens_min': 1,
    'max_new_tokens_max': 2000,
    'preset': 'NovelAI-Sphinx Moth',
    'name1': 'Person 1',
    'name2': 'Person 2',
    'context': 'This is a conversation between two people.',
    'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
    'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
    'stop_at_newline': True,
    'chat_prompt_size': 2048,
    'chat_prompt_size_min': 0,
    'chat_prompt_size_max': 2048,
    'preset_pygmalion': 'Pygmalion',
    'name1_pygmalion': 'You',
    'name2_pygmalion': 'Kawaii',
    'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
    'stop_at_newline_pygmalion': False,
}

if shared.args.settings is not None and Path(shared.args.settings).exists():
    new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
    for item in new_settings:
        settings[item] = new_settings[item]

if shared.args.flexgen:
    from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, Task, get_opt_config)

if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup
    local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed()
    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration

if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
    import modules.bot_picture as bot_picture
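
# Loads the selected model and its tokenizer, picking the loading path
# (default GPU, FlexGen, DeepSpeed ZeRO-3, or a custom from_pretrained call)
# based on the command-line flags in shared.args.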
def load_model(model_name):
    print(f"Loading {model_name}...")
    t0 = time.time()

    # Default settings
    if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen):
        if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda()

    # FlexGen
    elif shared.args.flexgen:
        gpu = TorchDevice("cuda:0")
        cpu = TorchDevice("cpu")
        disk = TorchDisk(shared.args.disk_cache_dir)
        env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))

        # Offloading policy
        policy = Policy(1, 1,
                        shared.args.percent[0], shared.args.percent[1],
                        shared.args.percent[2], shared.args.percent[3],
                        shared.args.percent[4], shared.args.percent[5],
                        overlap=True, sep_layer=True, pin_weight=True,
                        cpu_cache_compute=False, attn_sparsity=1.0,
                        compress_weight=shared.args.compress_weight,
                        comp_weight_config=CompressionConfig(
                            num_bits=4, group_size=64,
                            group_dim=0, symmetric=False),
                        compress_cache=False,
                        comp_cache_config=CompressionConfig(
                            num_bits=4, group_size=64,
                            group_dim=2, symmetric=False))

        opt_config = get_opt_config(f"facebook/{shared.model_name}")
        model = OptLM(opt_config, env, "models", policy)
        model.init_all_weights()

    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
        print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")

    # Custom
    else:
        command = "AutoModelForCausalLM.from_pretrained"
        params = ["low_cpu_mem_usage=True"]
        if not shared.args.cpu and not torch.cuda.is_available():
            print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n")
            shared.args.cpu = True

        if shared.args.cpu:
            params.append("low_cpu_mem_usage=True")
            params.append("torch_dtype=torch.float32")
        else:
            params.append("device_map='auto'")
            params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")

            if shared.args.gpu_memory:
                params.append(f"max_memory={{0: '{shared.args.gpu_memory or '99'}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
            elif not shared.args.load_in_8bit:
                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
                suggestion = round((total_mem - 1000) / 1000) * 1000
                if total_mem - suggestion < 800:
                    suggestion -= 1000
                suggestion = int(round(suggestion / 1000))
                print(f"\033[1;32;1mAuto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
                params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
            if shared.args.disk:
                params.append(f"offload_folder='{shared.args.disk_cache_dir}'")

        command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
        model = eval(command)

    # Loading the tokenizer
    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/"))
    tokenizer.truncation_side = 'left'

    print(f"Loaded the model in {(time.time() - t0):.2f} seconds.")
    return model, tokenizer
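
# Loads a soft prompt from softprompts/<name>.zip (a tensor.npy plus meta.json)
# and stores it in shared.soft_prompt / shared.soft_prompt_tensor for use
# during generation.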
def load_soft_prompt(name):
    if name == 'None':
        shared.soft_prompt = False
        shared.soft_prompt_tensor = None
    else:
        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
            zf.extract('tensor.npy')
            zf.extract('meta.json')
            j = json.loads(open('meta.json', 'r').read())
            print(f"\nLoading the softprompt \"{name}\".")
            for field in j:
                if field != 'name':
                    if type(j[field]) is list:
                        print(f"{field}: {', '.join(j[field])}")
                    else:
                        print(f"{field}: {j[field]}")
            print()
            tensor = np.load('tensor.npy')
            Path('tensor.npy').unlink()
            Path('meta.json').unlink()

        tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
        tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))

        shared.soft_prompt = True
        shared.soft_prompt_tensor = tensor

    return name
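
# Saves an uploaded soft prompt .zip under softprompts/, using the 'name'
# field from its meta.json as the file name.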
def upload_soft_prompt(file):
    with zipfile.ZipFile(io.BytesIO(file)) as zf:
        zf.extract('meta.json')
        j = json.loads(open('meta.json', 'r').read())
        name = j['name']
        Path('meta.json').unlink()

    with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
        f.write(file)

    return name

def load_model_wrapper(selected_model):
    if selected_model != shared.model_name:
        shared.model_name = selected_model

        # Unload the current model before loading the new one so its memory
        # can actually be reclaimed.
        shared.model = shared.tokenizer = None
        if not shared.args.cpu:
            gc.collect()
            torch.cuda.empty_cache()
        shared.model, shared.tokenizer = load_model(shared.model_name)

    return selected_model
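
# Reads a preset from presets/<name>.txt ("key=value" per line) and returns
# the resulting generation parameters, either as a dict or as the individual
# values in the order expected by the UI controls.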
def load_preset_values(preset_menu, return_dict=False):
    generate_params = {
        'do_sample': True,
        'temperature': 1,
        'top_p': 1,
        'typical_p': 1,
        'repetition_penalty': 1,
        'top_k': 50,
        'num_beams': 1,
        'penalty_alpha': 0,
        'min_length': 0,
        'length_penalty': 1,
        'no_repeat_ngram_size': 0,
        'early_stopping': False,
    }
    with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
        preset = infile.read()
    for i in preset.splitlines():
        i = i.rstrip(',').strip().split('=')
        if len(i) == 2 and i[0].strip() != 'tokens':
            generate_params[i[0].strip()] = eval(i[1].strip())

    generate_params['temperature'] = min(1.99, generate_params['temperature'])

    if return_dict:
        return generate_params
    else:
        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
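
# Helpers that enumerate the models, presets, characters, extensions, and
# soft prompts available on disk, used to populate the dropdown menus.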
def get_available_models():
    return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)

def get_available_presets():
    return sorted(set(map(lambda x: '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)

def get_available_characters():
    return ["None"] + sorted(set(map(lambda x: '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)

def get_available_extensions():
    return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)

def get_available_softprompts():
    return ["None"] + sorted(set(map(lambda x: '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
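
# Builds one UI control per parameter of every enabled extension, plus an
# Apply button that pushes the new values back to the extensions module.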
def create_extensions_block():
    extensions_ui_elements = []
    default_values = []
    if not (shared.args.chat or shared.args.cai_chat):
        gr.Markdown('## Extensions parameters')
    for ext in sorted(extension_state, key=lambda x: extension_state[x][1]):
        if extension_state[ext][0] == True:
            params = extensions_module.get_params(ext)
            for param in params:
                _id = f"{ext}-{param}"
                default_value = settings[_id] if _id in settings else params[param]
                default_values.append(default_value)
                if type(params[param]) == str:
                    extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
                elif type(params[param]) in [int, float]:
                    extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
                elif type(params[param]) == bool:
                    extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))

    update_extensions_parameters(*default_values)
    btn_extensions = gr.Button("Apply")
    btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
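
# Builds the model/preset dropdowns, the generation-parameter sliders, and the
# soft prompt controls shared by all interface modes, and wires up their events.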
def create_settings_menus():
    generate_params = load_preset_values(settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
                create_refresh_button(model_menu, lambda: None, lambda: {"choices": get_available_models()}, "refresh-button")
        with gr.Column():
            with gr.Row():
                preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
                create_refresh_button(preset_menu, lambda: None, lambda: {"choices": get_available_presets()}, "refresh-button")

    with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"):
        with gr.Row():
            do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample")
            temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature")
        with gr.Row():
            top_k = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label="top_k")
            top_p = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label="top_p")
        with gr.Row():
            repetition_penalty = gr.Slider(1.0, 4.99, value=generate_params['repetition_penalty'], step=0.01, label="repetition_penalty")
            no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
        with gr.Row():
            typical_p = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label="typical_p")
            min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if shared.args.no_stream else 0, label="min_length", interactive=shared.args.no_stream)

        gr.Markdown("Contrastive search:")
        penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha")

        gr.Markdown("Beam search (uses a lot of VRAM):")
        with gr.Row():
            num_beams = gr.Slider(1, 20, step=1, value=generate_params["num_beams"], label="num_beams")
            length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty")
        early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")

    with gr.Accordion("Soft prompt", open=False, elem_id="accordion"):
        with gr.Row():
            softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
            create_refresh_button(softprompts_menu, lambda: None, lambda: {"choices": get_available_softprompts()}, "refresh-button")

        gr.Markdown('Upload a soft prompt (.zip format):')
        with gr.Row():
            upload_softprompt = gr.File(type='binary', file_types=[".zip"])

    model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
    preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
    softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
    upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])

    return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping

# Global variables
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
extensions_module.available_extensions = get_available_extensions()
available_softprompts = get_available_softprompts()

if shared.args.extensions is not None:
    load_extensions()

# Choosing the default model
if shared.args.model is not None:
    shared.model_name = shared.args.model
else:
    if len(available_models) == 0:
        print("No models are available! Please download at least one.")
        sys.exit(0)
    elif len(available_models) == 1:
        i = 0
    else:
        print("The following models are available:\n")
        for i, model in enumerate(available_models):
            print(f"{i+1}. {model}")
        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
        i = int(input()) - 1
        print()
    shared.model_name = available_models[i]

shared.model, shared.tokenizer = load_model(shared.model_name)
loaded_preset = None

# UI settings
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
    default_text = settings['prompt_gpt4chan']
elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
    default_text = 'User: \n'
else:
    default_text = settings['prompt']
description = "\n\n# Text generation lab\nGenerate text using Large Language Models.\n"

suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
buttons = {}
gen_events = []
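
# Build the Gradio interface. Three layouts are available: chat (or cai-chat),
# notebook, and the default "input box -> output tabs" layout.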
if shared.args.chat or shared.args.cai_chat:
    if Path('logs/persistent.json').exists():
        chat.load_history(open(Path('logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}'])

    with gr.Blocks(css=css + chat_css, analytics_enabled=False) as interface:
        if shared.args.cai_chat:
            display = gr.HTML(value=generate_chat_html(chat.history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], chat.character))
        else:
            display = gr.Chatbot(value=chat.history['visible'])
        textbox = gr.Textbox(label='Input')
        with gr.Row():
            buttons["Stop"] = gr.Button("Stop")
            buttons["Generate"] = gr.Button("Generate")
            buttons["Regenerate"] = gr.Button("Regenerate")
        with gr.Row():
            buttons["Impersonate"] = gr.Button("Impersonate")
            buttons["Remove last"] = gr.Button("Remove last")
            buttons["Clear history"] = gr.Button("Clear history")
        with gr.Row():
            buttons["Send last reply to input"] = gr.Button("Send last reply to input")
            buttons["Replace last reply"] = gr.Button("Replace last reply")
        if shared.args.picture:
            with gr.Row():
                picture_select = gr.Image(label="Send a picture", type='pil')

        with gr.Tab("Chat settings"):
            name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
            name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
            context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
            with gr.Row():
                character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
                create_refresh_button(character_menu, lambda: None, lambda: {"choices": get_available_characters()}, "refresh-button")

            with gr.Row():
                check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
            with gr.Row():
                with gr.Tab('Chat history'):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown('Upload')
                            upload_chat_history = gr.File(type='binary', file_types=[".json", ".txt"])
                        with gr.Column():
                            gr.Markdown('Download')
                            download = gr.File()
                            buttons["Download"] = gr.Button(value="Click me")
                with gr.Tab('Upload character'):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown('1. Select the JSON file')
                            upload_char = gr.File(type='binary', file_types=[".json"])
                        with gr.Column():
                            gr.Markdown('2. Select your character\'s profile picture (optional)')
                            upload_img = gr.File(type='binary', file_types=["image"])
                    buttons["Upload character"] = gr.Button(value="Submit")
                with gr.Tab('Upload your profile picture'):
                    upload_img_me = gr.File(type='binary', file_types=["image"])
                with gr.Tab('Upload TavernAI Character Card'):
                    upload_img_tavern = gr.File(type='binary', file_types=["image"])

        with gr.Tab("Generation settings"):
            with gr.Row():
                with gr.Column():
                    max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                with gr.Column():
                    chat_prompt_size_slider = gr.Slider(minimum=settings['chat_prompt_size_min'], maximum=settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=settings['chat_prompt_size'])

            preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()

        if shared.args.extensions is not None:
            with gr.Tab("Extensions"):
                create_extensions_block()

        input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider]
        if shared.args.picture:
            input_params.append(picture_select)
        function_call = "chat.cai_chatbot_wrapper" if shared.args.cai_chat else "chat.chatbot_wrapper"

        gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=shared.args.no_stream, api_name="textgen"))
        gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=shared.args.no_stream))
        if shared.args.picture:
            picture_select.upload(eval(function_call), input_params, display, show_progress=shared.args.no_stream)
        gen_events.append(buttons["Regenerate"].click(chat.regenerate_wrapper, input_params, display, show_progress=shared.args.no_stream))
        gen_events.append(buttons["Impersonate"].click(chat.impersonate_wrapper, input_params, textbox, show_progress=shared.args.no_stream))
        buttons["Stop"].click(chat.stop_everything_event, [], [], cancels=gen_events)

        buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream)
        buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream)
        buttons["Clear history"].click(chat.clear_chat_log, [character_menu, name1, name2], display)
        buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False)
        buttons["Download"].click(chat.save_history, inputs=[], outputs=[download])
        buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu])

        # Clearing stuff and saving the history
        for i in ["Generate", "Regenerate", "Replace last reply"]:
            buttons[i].click(lambda x: "", textbox, textbox, show_progress=False)
            buttons[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
        buttons["Clear history"].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
        textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
        textbox.submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)

        character_menu.change(chat.load_character, [character_menu, name1, name2], [name2, context, display])
        upload_chat_history.upload(chat.load_history, [upload_chat_history, name1, name2], [])
        upload_img_tavern.upload(chat.upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
        upload_img_me.upload(chat.upload_your_profile_picture, [upload_img_me], [])
        if shared.args.picture:
            picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
        if shared.args.cai_chat:
            upload_chat_history.upload(chat.redraw_html, [name1, name2], [display])
            upload_img_me.upload(chat.redraw_html, [name1, name2], [display])
        else:
            upload_chat_history.upload(lambda: chat.history['visible'], [], [display])
            upload_img_me.upload(lambda: chat.history['visible'], [], [display])

elif shared.args.notebook:
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Tab('Raw'):
            textbox = gr.Textbox(value=default_text, lines=23)
        with gr.Tab('Markdown'):
            markdown = gr.Markdown()
        with gr.Tab('HTML'):
            html = gr.HTML()

        buttons["Generate"] = gr.Button("Generate")
        buttons["Stop"] = gr.Button("Stop")

        max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])

        preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()

        if shared.args.extensions is not None:
            create_extensions_block()

        gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
        gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream))
        buttons["Stop"].click(None, None, None, cancels=gen_events)

else:
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                buttons["Generate"] = gr.Button("Generate")
                with gr.Row():
                    with gr.Column():
                        buttons["Continue"] = gr.Button("Continue")
                    with gr.Column():
                        buttons["Stop"] = gr.Button("Stop")

                preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
                if shared.args.extensions is not None:
                    create_extensions_block()

            with gr.Column():
                with gr.Tab('Raw'):
                    output_textbox = gr.Textbox(lines=15, label='Output')
                with gr.Tab('Markdown'):
                    markdown = gr.Markdown()
                with gr.Tab('HTML'):
                    html = gr.HTML()

        gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
        gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream))
        gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream))
        buttons["Stop"].click(None, None, None, cancels=gen_events)

interface.queue()
if shared.args.listen:
    interface.launch(prevent_thread_lock=True, share=shared.args.share, server_name="0.0.0.0", server_port=shared.args.listen_port)
else:
    interface.launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)

# Keep the main thread alive: launch() returns immediately when
# prevent_thread_lock=True is set.
while True:
    time.sleep(0.5)