import re
import gc
import time
import glob
import copy
import json
import argparse
import warnings
from sys import exit
from pathlib import Path

import torch
import gradio as gr
import transformers
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from modules.html_generator import *
from modules.ui import *

transformers.logging.set_verbosity_error()

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file profile.png or profile.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
args = parser.parse_args()

if (args.chat or args.cai_chat) and not args.no_stream:
    print("Warning: chat mode currently becomes a lot slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")

settings = {
    'max_new_tokens': 200,
    'max_new_tokens_min': 1,
    'max_new_tokens_max': 2000,
    'preset': 'NovelAI-Sphinx Moth',
    'name1': 'Person 1',
    'name2': 'Person 2',
    'context': 'This is a conversation between two people.',
    'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
    'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
    'stop_at_newline': True,
    'history_size': 0,
    'history_size_min': 0,
    'history_size_max': 64,
    'preset_pygmalion': 'Pygmalion',
    'name1_pygmalion': 'You',
    'name2_pygmalion': 'Kawaii',
    'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
    'stop_at_newline_pygmalion': False,
}

if args.settings is not None and Path(args.settings).exists():
    with open(Path(args.settings), 'r') as f:
        new_settings = json.load(f)
    for item in new_settings:
        if item in settings:
            settings[item] = new_settings[item]
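
# For reference, a file passed via --settings only needs to contain the keys it
# overrides (see settings-template.json). The values below are illustrative only:
# {
#   "max_new_tokens": 400,
#   "preset": "NovelAI-Sphinx Moth",
#   "name1": "You",
#   "name2": "Kawaii"
# }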

def load_model(model_name):
    print(f"Loading {model_name}...")
    t0 = time.time()

    # Default settings
    if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None):
        if Path(f"torch-dumps/{model_name}.pt").exists():
            print("Loading in .pt format...")
            model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
        elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

    # Custom
    else:
        settings = ["low_cpu_mem_usage=True"]
        command = "AutoModelForCausalLM.from_pretrained"

        if args.cpu:
            settings.append("torch_dtype=torch.float32")
        else:
            settings.append("device_map='auto'")
            if args.gpu_memory is not None:
                if args.cpu_memory is not None:
                    settings.append(f"max_memory={{0: '{args.gpu_memory}GiB', 'cpu': '{args.cpu_memory}GiB'}}")
                else:
                    settings.append(f"max_memory={{0: '{args.gpu_memory}GiB', 'cpu': '99GiB'}}")
            if args.disk:
                if args.disk_cache_dir is not None:
                    settings.append(f"offload_folder='{args.disk_cache_dir}'")
                else:
                    settings.append("offload_folder='cache'")
            if args.load_in_8bit:
                settings.append("load_in_8bit=True")
            else:
                settings.append("torch_dtype=torch.float16")

        settings = ', '.join(set(settings))
        command = f"{command}(Path(f'models/{model_name}'), {settings})"
        model = eval(command)

    # Loading the tokenizer
    if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
    tokenizer.truncation_side = 'left'

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
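
# The "custom" branch above assembles the from_pretrained() call as a string and
# eval()s it. With, say, --auto-devices --gpu-memory 10 the assembled call would
# look roughly like this (argument order varies because of set(); model name is
# illustrative):
#   AutoModelForCausalLM.from_pretrained(Path(f'models/gpt-j-6B'),
#       low_cpu_mem_usage=True, device_map='auto',
#       max_memory={0: '10GiB', 'cpu': '99GiB'}, torch_dtype=torch.float16)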

# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
    return s

# Fix the LaTeX equations in galactica
def fix_galactica(s):
    s = s.replace(r'\[', r'$')
    s = s.replace(r'\]', r'$')
    s = s.replace(r'\(', r'$')
    s = s.replace(r'\)', r'$')
    s = s.replace(r'$$', r'$')
    return s
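
# e.g. fix_galactica(r'\[ E = mc^2 \]') returns '$ E = mc^2 $', turning the LaTeX
# display-math delimiters into plain '$' so the output tabs can render them inline.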

def encode(prompt, tokens):
    if args.cpu:
        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
    else:
        torch.cuda.empty_cache()
        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
    return input_ids

def decode(output_ids):
    reply = tokenizer.decode(output_ids, skip_special_tokens=True)
    reply = reply.replace(r'<|endoftext|>', '')
    return reply

def formatted_outputs(reply, model_name):
    if not (args.chat or args.cai_chat):
        if model_name.lower().startswith('galactica'):
            reply = fix_galactica(reply)
            return reply, reply, generate_basic_html(reply)
        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
            reply = fix_gpt4chan(reply)
            return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
        else:
            return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
    else:
        return reply
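
# In the non-chat modes the three return values above feed the Raw, Markdown and
# HTML output components that are wired up further down in the interface.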

def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
    global model, tokenizer, model_name, loaded_preset, preset

    if selected_model != model_name:
        model_name = selected_model
        model = tokenizer = None
        if not args.cpu:
            gc.collect()
            torch.cuda.empty_cache()
        model, tokenizer = load_model(model_name)
    if inference_settings != loaded_preset:
        with open(Path(f'presets/{inference_settings}.txt'), 'r') as infile:
            preset = infile.read()
        loaded_preset = inference_settings

    cuda = "" if args.cpu else ".cuda()"
    n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
    input_ids = encode(question, tokens)

    # Generate the entire reply at once
    if args.no_stream:
        t0 = time.time()
        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
        reply = decode(output[0])
        t1 = time.time()
        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
        yield formatted_outputs(reply, model_name)

    # Generate the reply 1 token at a time
    else:
        yield formatted_outputs(question, model_name)
        preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
        for i in tqdm(range(tokens)):
            output = eval(f"model.generate(input_ids, {preset}){cuda}")
            reply = decode(output[0])
            if eos_token is not None and reply[-1] == eos_token:
                break
            yield formatted_outputs(reply, model_name)
            input_ids = output
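
# The preset text read above is spliced verbatim into model.generate(), so a preset
# file is presumably a comma-separated list of generate() keyword arguments along
# these lines (illustrative, not an actual preset file):
#   do_sample=True, top_p=0.9, temperature=0.7, max_new_tokens=tokens
# The streaming branch rewrites 'max_new_tokens=tokens' to 'max_new_tokens=1' so it
# can yield one token per iteration.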

def get_available_models():
    return sorted(set([item.replace('.pt', '') for item in map(lambda x: str(x.name), list(Path('models/').glob('*')) + list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)

def get_available_presets():
    return sorted(set(map(lambda x: '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)

def get_available_characters():
    return ["None"] + sorted(set(map(lambda x: '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)

available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()

# Choosing the default model
if args.model is not None:
    model_name = args.model
else:
    if len(available_models) == 0:
        print("No models are available! Please download at least one.")
        exit(0)
    elif len(available_models) == 1:
        i = 0
    else:
        print("The following models are available:\n")
        for i, model in enumerate(available_models):
            print(f"{i+1}. {model}")
        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
        i = int(input())-1
        print()
    model_name = available_models[i]

model, tokenizer = load_model(model_name)
loaded_preset = None

# UI settings
default_text = settings['prompt_gpt4chan'] if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) else settings['prompt']
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"

if args.chat or args.cai_chat:
    history = []
    character = None

    # This gets the new line characters right.
    def clean_chat_message(text):
        text = text.replace('\n', '\n\n')
        text = re.sub(r"\n{3,}", "\n\n", text)
        text = text.strip()
        return text

    def generate_chat_prompt(text, tokens, name1, name2, context, history_size):
        text = clean_chat_message(text)

        rows = [f"{context.strip()}\n"]
        i = len(history)-1
        count = 0
        while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
            count += 1
            if not (history[i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
                rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
                count += 1
            i -= 1
            if history_size != 0 and count >= history_size:
                break

        rows.append(f"{name1}: {text}\n")
        rows.append(f"{name2}:")
        while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
            rows.pop(1)
            rows.pop(1)

        question = ''.join(rows)
        return question
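
    # For reference, the prompt assembled above looks roughly like this, using the
    # pygmalion defaults for the names (illustrative):
    #   This is a conversation between two people.
    #   You: Hi
    #   Kawaii: Hello!
    #   You: <the new message>
    #   Kawaii: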

    def remove_example_dialogue_from_history(history):
        _history = copy.deepcopy(history)
        for i in range(len(_history)):
            if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
                _history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
                _history = _history[i:]
                break
        return _history

    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
        question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
        history.append(['', ''])
        eos_token = '\n' if check else None
        for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
            next_character_found = False

            # Keep only the part of the reply that comes after the last
            # "{name2}:" marker already present in the prompt
            previous_idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", question)]
            idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
            idx = idx[len(previous_idx)-1]
            reply = reply[idx + len(f"\n{name2}:"):]

            if check:
                reply = reply.split('\n')[0].strip()
            else:
                idx = reply.find(f"\n{name1}:")
                if idx != -1:
                    reply = reply[:idx]
                    next_character_found = True
                reply = clean_chat_message(reply)

            history[-1] = [text, reply]
            if next_character_found:
                break

            # Prevent the chat log from flashing if something like "\nYo" is generated just
            # before "\nYou:" is completed
            tmp = f"\n{name1}:"
            next_character_substring_found = False
            for j in range(1, len(tmp)):
                if reply[-j:] == tmp[:j]:
                    next_character_substring_found = True

            if not next_character_substring_found:
                yield remove_example_dialogue_from_history(history)

        yield remove_example_dialogue_from_history(history)

    def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
            yield generate_chat_html(history, name1, name2, character)

    def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
        last = history.pop()
        text = last[0]
        if args.cai_chat:
            for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
                yield i
        else:
            for i in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
                yield i

    def remove_last_message(name1, name2):
        last = history.pop()
        _history = remove_example_dialogue_from_history(history)
        if args.cai_chat:
            return generate_chat_html(_history, name1, name2, character), last[0]
        else:
            return _history, last[0]

    def clear():
        global history
        history = []

    def clear_html():
        return generate_chat_html([], "", "", character)

    def redraw_html(name1, name2):
        global history
        _history = remove_example_dialogue_from_history(history)
        return generate_chat_html(_history, name1, name2, character)

    def tokenize_dialogue(dialogue, name1, name2):
        dialogue = re.sub('<START>', '', dialogue)
        dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
        idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]

        messages = []
        for i in range(len(idx)-1):
            messages.append(dialogue[idx[i]:idx[i+1]].strip())

        history = []
        entry = ['', '']
        for i in messages:
            if i.startswith(f'{name1}:'):
                entry[0] = i[len(f'{name1}:'):].strip()
            elif i.startswith(f'{name2}:'):
                entry[1] = i[len(f'{name2}:'):].strip()
                if not (len(entry[0]) == 0 and len(entry[1]) == 0):
                    history.append(entry)
                entry = ['', '']
        return history
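
    # tokenize_dialogue expects plain text of alternating "<name>:" lines, e.g.
    # (illustrative; "Anon:" is rewritten to "You:" by the regex above):
    #   You: Hi
    #   Kawaii: Hello! How are you?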

    def save_history():
        if not Path('logs').exists():
            Path('logs').mkdir()
        with open(Path('logs/conversation.json'), 'w') as f:
            f.write(json.dumps({'data': history}, indent=2))
        return Path('logs/conversation.json')

    def load_history(file, name1, name2):
        global history
        file = file.decode('utf-8')
        try:
            history = json.loads(file)['data']
        except:
            history = tokenize_dialogue(file, name1, name2)

    def load_character(_character, name1, name2):
        global history, character
        context = ""
        history = []
        if _character != 'None':
            character = _character
            with open(Path(f'characters/{_character}.json'), 'r') as f:
                data = json.loads(f.read())
            name2 = data['char_name']
            if 'char_persona' in data and data['char_persona'] != '':
                context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
            if 'world_scenario' in data and data['world_scenario'] != '':
                context += f"Scenario: {data['world_scenario']}\n"
            context = f"{context.strip()}\n<START>\n"
            if 'example_dialogue' in data and data['example_dialogue'] != '':
                history = tokenize_dialogue(data['example_dialogue'], name1, name2)
            if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
                history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
            else:
                history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
        else:
            character = None
            context = settings['context_pygmalion']
            name2 = settings['name2_pygmalion']

        _history = remove_example_dialogue_from_history(history)
        if args.cai_chat:
            return name2, context, generate_chat_html(_history, name1, name2, character)
        else:
            return name2, context, _history
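
    # A character file at characters/<name>.json is read for the keys used above;
    # a minimal sketch of such a file (values made up for illustration):
    #   {
    #     "char_name": "Kawaii",
    #     "char_persona": "A cheerful person who loves to make others smile.",
    #     "world_scenario": "",
    #     "char_greeting": "Hello there!",
    #     "example_dialogue": "You: Hi\nKawaii: Hello!"
    #   }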

    suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''

    with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
        if args.cai_chat:
            display1 = gr.HTML(value=generate_chat_html([], "", "", character))
        else:
            display1 = gr.Chatbot()
        textbox = gr.Textbox(lines=2, label='Input')
        btn = gr.Button("Generate")
        with gr.Row():
            stop = gr.Button("Stop")
            btn_regenerate = gr.Button("Regenerate")
            btn_remove_last = gr.Button("Remove last")
            btn_clear = gr.Button("Clear history")

        with gr.Row():
            with gr.Column():
                length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                with gr.Row():
                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                    create_refresh_button(model_menu, lambda: None, lambda: {"choices": get_available_models()}, "refresh-button")
            with gr.Column():
                history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size'])
                with gr.Row():
                    preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
                    create_refresh_button(preset_menu, lambda: None, lambda: {"choices": get_available_presets()}, "refresh-button")

        name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
        name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
        context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
        with gr.Row():
            character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
            create_refresh_button(character_menu, lambda: None, lambda: {"choices": get_available_characters()}, "refresh-button")
        with gr.Row():
            check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
        with gr.Row():
            with gr.Tab('Download chat history'):
                download = gr.File()
                save_btn = gr.Button(value="Click me")
            with gr.Tab('Upload chat history'):
                upload = gr.File(type='binary')

        input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
        if args.cai_chat:
            gen_event = btn.click(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
            gen_event2 = textbox.submit(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
            btn_clear.click(clear_html, [], display1, show_progress=False)
        else:
            gen_event = btn.click(chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
            gen_event2 = textbox.submit(chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
            btn_clear.click(lambda x: "", display1, display1, show_progress=False)
        # Defined for both chat flavors so the Stop button below can cancel it
        gen_event3 = btn_regenerate.click(regenerate_wrapper, input_params, display1, show_progress=args.no_stream)

        btn_clear.click(clear)
        btn_remove_last.click(remove_last_message, [name1, name2], [display1, textbox], show_progress=False)
        btn.click(lambda x: "", textbox, textbox, show_progress=False)
        btn_regenerate.click(lambda x: "", textbox, textbox, show_progress=False)
        textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
        stop.click(None, None, None, cancels=[gen_event, gen_event2, gen_event3])
        save_btn.click(save_history, inputs=[], outputs=[download])
        upload.upload(load_history, [upload, name1, name2], [])
        character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])

        if args.cai_chat:
            upload.upload(redraw_html, [name1, name2], [display1])
        else:
            upload.upload(lambda: remove_example_dialogue_from_history(history), [], [display1])

elif args.notebook:
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Tab('Raw'):
            textbox = gr.Textbox(value=default_text, lines=23)
        with gr.Tab('Markdown'):
            markdown = gr.Markdown()
        with gr.Tab('HTML'):
            html = gr.HTML()
        btn = gr.Button("Generate")
        stop = gr.Button("Stop")

        length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                    create_refresh_button(model_menu, lambda: None, lambda: {"choices": get_available_models()}, "refresh-button")
            with gr.Column():
                with gr.Row():
                    preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
                    create_refresh_button(preset_menu, lambda: None, lambda: {"choices": get_available_presets()}, "refresh-button")

        gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
        gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
        stop.click(None, None, None, cancels=[gen_event, gen_event2])

else:
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                with gr.Row():
                    preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
                    create_refresh_button(preset_menu, lambda: None, lambda: {"choices": get_available_presets()}, "refresh-button")
                with gr.Row():
                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                    create_refresh_button(model_menu, lambda: None, lambda: {"choices": get_available_models()}, "refresh-button")
                btn = gr.Button("Generate")
                with gr.Row():
                    with gr.Column():
                        cont = gr.Button("Continue")
                    with gr.Column():
                        stop = gr.Button("Stop")
            with gr.Column():
                with gr.Tab('Raw'):
                    output_textbox = gr.Textbox(lines=15, label='Output')
                with gr.Tab('Markdown'):
                    markdown = gr.Markdown()
                with gr.Tab('HTML'):
                    html = gr.HTML()

        gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
        gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
        cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
        stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event])

interface.queue()
if args.listen:
    interface.launch(share=args.share, server_name="0.0.0.0")
else:
    interface.launch(share=args.share)