import argparse
import glob
import re
import time
from pathlib import Path
from sys import exit

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
from transformers import AutoTokenizer, T5Tokenizer

from html_generator import *

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
args = parser.parse_args()

loaded_preset = None

# Models can live in models/ (HF checkout) or torch-dumps/ (.pt dump);
# strip the .pt suffix so both spellings collapse to one menu entry.
available_models = sorted({p.name.replace('.pt', '') for p in list(Path('models/').glob('*')) + list(Path('torch-dumps/').glob('*'))})
available_models = [item for item in available_models if not item.endswith('.txt')]
available_presets = sorted({p.name.split('.')[0] for p in Path('presets').glob('*.txt')})
2023-01-05 23:33:21 -05:00
2022-12-21 11:27:31 -05:00
def load_model(model_name):
    """Load *model_name* and its tokenizer onto the GPU.

    Prefers a pre-converted torch-dumps/<name>.pt file; otherwise loads
    from models/<name> with dtype/8-bit settings chosen by model family
    and size. Returns a (model, tokenizer) pair.
    """
    print(f"Loading {model_name}...")
    start = time.time()

    pt_path = Path(f"torch-dumps/{model_name}.pt")
    name_lower = model_name.lower()

    # Loading the model
    if pt_path.exists():
        print("Loading in .pt format...")
        model = torch.load(pt_path).cuda()
    elif name_lower.startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in name_lower for size in ('13b', '20b', '30b')):
        # Very large checkpoints: shard across devices and quantize to 8-bit.
        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
    elif model_name in ['flan-t5', 't5-large']:
        # Seq2seq models need the T5 class rather than a causal LM.
        model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
    else:
        # Everything else (small gpt-neo/opt/galactica, gpt-j-6B, ...)
        # shares the same fp16 causal-LM load.
        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

    # Loading the tokenizer
    if model_name.startswith('gpt4chan'):
        # gpt4chan has no tokenizer of its own; it reuses gpt-j-6B's.
        tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
    elif model_name in ['flan-t5']:
        tokenizer = T5Tokenizer.from_pretrained(Path(f"models/{model_name}/"))
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))

    print(f"Loaded the model in {(time.time()-start):.2f} seconds.")
    return model, tokenizer
2023-01-06 00:26:33 -05:00
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    """Collapse empty gpt4chan replies in *s* into a bare ``---`` separator.

    The substitutions run in up to 10 passes because collapsing one empty
    reply can bring two other empty replies adjacent to each other; we stop
    early once a pass changes nothing (fixed point reached).
    """
    # Compiled once instead of re-compiling three patterns on every pass.
    patterns = (
        re.compile("--- [0-9]*\n>>[0-9]*\n---"),  # reply that only quotes a post number
        re.compile("--- [0-9]*\n*\n---"),         # reply that is only blank lines
        re.compile("--- [0-9]*\n\n\n---"),        # reply that is exactly two blank lines
    )
    for _ in range(10):
        previous = s
        for pattern in patterns:
            s = pattern.sub("---", s)
        if s == previous:
            break  # no pattern matched; further passes would be no-ops
    return s
2023-01-06 23:56:21 -05:00
def fix_galactica(s):
    """Normalize galactica LaTeX delimiters in *s*.

    ``\\[``, ``\\]``, ``\\(`` and ``\\)`` all become ``$``; then any
    doubled ``$$`` collapses to a single ``$``. Order matters: the ``$$``
    pass must come last so delimiters converted above are also collapsed.
    """
    for delimiter in ('\\[', '\\]', '\\(', '\\)', '$$'):
        s = s.replace(delimiter, '$')
    return s
2023-01-06 00:26:33 -05:00
def generate_reply(question, temperature, max_length, inference_settings, selected_model):
    """Generate a reply to *question*, reloading model/preset if the UI
    selection changed.

    Returns a 3-tuple (raw text, markdown text, html text) feeding the
    three output tabs; tabs that do not apply to the current model family
    receive a placeholder string instead.
    """
    global model, tokenizer, model_name, loaded_preset, preset

    # Swap models when the dropdown selection changed.
    if selected_model != model_name:
        model_name = selected_model
        model = None
        # BUGFIX: was `tokenier = None` (typo), so the previous tokenizer
        # was never dropped before loading the replacement model.
        tokenizer = None
        torch.cuda.empty_cache()
        model, tokenizer = load_model(model_name)
    # Re-read the preset file only when a different preset is selected.
    if inference_settings != loaded_preset:
        with open(Path(f'presets/{inference_settings}.txt'), 'r') as infile:
            preset = infile.read()
        loaded_preset = inference_settings

    torch.cuda.empty_cache()
    input_text = question
    input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()

    # SECURITY: the preset file's text is executed as Python source via
    # eval(); only ship/load presets from trusted sources. NOTE(review):
    # `temperature` and `max_length` are presumably referenced by the
    # eval'd preset text, which is why they must stay in scope — confirm
    # against the preset files.
    output = eval(f"model.generate(input_ids, {preset}).cuda()")
    reply = tokenizer.decode(output[0], skip_special_tokens=True)

    if model_name.lower().startswith('galactica'):
        reply = fix_galactica(reply)
        return reply, reply, 'Only applicable for gpt4chan.'
    elif model_name.lower().startswith('gpt4chan'):
        reply = fix_gpt4chan(reply)
        return reply, 'Only applicable for galactica models.', generate_html(reply)
    else:
        return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'
2022-12-21 11:27:31 -05:00
2023-01-06 17:56:44 -05:00
# Choosing the default model
if args.model is not None:
    model_name = args.model
else:
    if not available_models:
        print("No models are available! Please download at least one.")
        exit(0)
    if len(available_models) == 1:
        # Only one candidate: no need to prompt.
        chosen = 0
    else:
        print("The following models are available:\n")
        for chosen, model in enumerate(available_models):
            print(f"{chosen+1}. {model}")
        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
        chosen = int(input()) - 1
    model_name = available_models[chosen]
2022-12-21 11:27:31 -05:00
model, tokenizer = load_model(model_name)

# Seed text tailored to the model family.
if model_name.startswith('gpt4chan'):
    default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
else:
    default_text = "Common sense questions and answers\n\nQuestion:\nFactual answer:"

if args.notebook:
    # Notebook mode: input and output share the same textbox.
    with gr.Blocks() as interface:
        gr.Markdown(
            f"""
            # Text generation lab
            Generate text using Large Language Models.
            """
        )

        with gr.Tab('Raw'):
            textbox = gr.Textbox(value=default_text, lines=23)
        with gr.Tab('Markdown'):
            markdown = gr.Markdown()
        with gr.Tab('HTML'):
            html = gr.HTML()
        btn = gr.Button("Generate")

        with gr.Row():
            with gr.Column():
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
            with gr.Column():
                preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')

        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
else:
    # Default mode: separate input and output columns.
    with gr.Blocks() as interface:
        gr.Markdown(
            f"""
            # Text generation lab
            Generate text using Large Language Models.
            """
        )

        with gr.Row():
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
                preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                btn = gr.Button("Generate")
            with gr.Column():
                with gr.Tab('Raw'):
                    output_textbox = gr.Textbox(value=default_text, lines=15, label='Output')
                with gr.Tab('Markdown'):
                    markdown = gr.Markdown()
                with gr.Tab('HTML'):
                    html = gr.HTML()

        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)

interface.launch(share=False, server_name="0.0.0.0")