'
for post in posts:
@@ -98,13 +104,15 @@ def generate_4chan_html(f):
return output
+
def make_thumbnail(image):
- image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
+ image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
if image.size[1] > 470:
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
return image
+
def get_image_cache(path):
cache_folder = Path("cache")
if not cache_folder.exists():
@@ -119,9 +127,10 @@ def get_image_cache(path):
return image_cache[path][1]
+
def generate_instruct_html(history):
output = f'
'
- for i,_row in enumerate(history[::-1]):
+ for i, _row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row]
output += f"""
@@ -134,7 +143,7 @@ def generate_instruct_html(history):
"""
- if len(row[0]) == 0: # don't display empty user messages
+ if len(row[0]) == 0: # don't display empty user messages
continue
output += f"""
@@ -151,6 +160,7 @@ def generate_instruct_html(history):
return output
+
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'
'
@@ -159,7 +169,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
img_bot = f'
' if Path("cache/pfp_character.png").exists() else ''
img_me = f'
' if Path("cache/pfp_me.png").exists() else ''
- for i,_row in enumerate(history[::-1]):
+ for i, _row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row]
output += f"""
@@ -178,7 +188,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
"""
- if len(row[0]) == 0: # don't display empty user messages
+ if len(row[0]) == 0: # don't display empty user messages
continue
output += f"""
@@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output += "
"
return output
+
def generate_chat_html(history, name1, name2):
return generate_cai_chat_html(history, name1, name2)
+
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
if mode == "cai-chat":
return generate_cai_chat_html(history, name1, name2, reset_cache)
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 4f491329..9461db10 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -50,9 +50,9 @@ class LlamaCppModel:
params.top_k = top_k
params.temp = temperature
params.repeat_penalty = repetition_penalty
- #params.repeat_last_n = repeat_last_n
+ # params.repeat_last_n = repeat_last_n
- #self.model.params = params
+ # self.model.params = params
self.model.add_bos()
self.model.update_input(context)
diff --git a/modules/llamacpp_model_alternative.py b/modules/llamacpp_model_alternative.py
index 40576113..8fea2ab4 100644
--- a/modules/llamacpp_model_alternative.py
+++ b/modules/llamacpp_model_alternative.py
@@ -1,13 +1,11 @@
'''
-Based on
+Based on
https://github.com/abetlen/llama-cpp-python
Documentation:
https://abetlen.github.io/llama-cpp-python/
'''
-import multiprocessing
-
from llama_cpp import Llama
from modules import shared
@@ -31,7 +29,7 @@ class LlamaCppModel:
self.model = Llama(**params)
# This is ugly, but the model and the tokenizer are the same object in this library.
- return result, result
+ return result, result
def encode(self, string):
if type(string) is str:
diff --git a/modules/models.py b/modules/models.py
index 1bf6fc37..5e2b0989 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -34,7 +34,7 @@ if shared.args.deepspeed:
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
- dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
+ dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
def load_model(model_name):
@@ -83,7 +83,7 @@ def load_model(model_name):
elif shared.args.deepspeed:
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
- model.module.eval() # Inference
+ model.module.eval() # Inference
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
# RMKV model (not on HuggingFace)
@@ -132,7 +132,7 @@ def load_model(model_name):
params["torch_dtype"] = torch.float16
if shared.args.gpu_memory:
- memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
+ memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
max_memory = {}
for i in range(len(memory_map)):
@@ -140,13 +140,13 @@ def load_model(model_name):
max_memory['cpu'] = max_cpu_memory
params['max_memory'] = max_memory
elif shared.args.auto_devices:
- total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
- suggestion = round((total_mem-1000) / 1000) * 1000
+ total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+ suggestion = round((total_mem - 1000) / 1000) * 1000
if total_mem - suggestion < 800:
suggestion -= 1000
- suggestion = int(round(suggestion/1000))
+ suggestion = int(round(suggestion / 1000))
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
-
+
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
params['max_memory'] = max_memory
@@ -161,10 +161,10 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_config(config)
model.tie_weights()
params['device_map'] = infer_auto_device_map(
- model,
- dtype=torch.int8,
+ model,
+ dtype=torch.int8,
max_memory=params['max_memory'],
- no_split_module_classes = model._no_split_modules
+ no_split_module_classes=model._no_split_modules
)
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
@@ -181,6 +181,7 @@ def load_model(model_name):
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer
+
def load_soft_prompt(name):
if name == 'None':
shared.soft_prompt = False
diff --git a/modules/shared.py b/modules/shared.py
index 902d7609..7ff1ca28 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -61,6 +61,7 @@ settings = {
}
}
+
def str2bool(v):
if isinstance(v, bool):
return v
@@ -71,7 +72,8 @@ def str2bool(v):
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
-parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
# Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
@@ -145,5 +147,6 @@ if args.cai_chat:
print("Warning: --cai-chat is deprecated. Use --chat instead.")
args.chat = True
+
def is_chat():
return args.chat
diff --git a/modules/text_generation.py b/modules/text_generation.py
index b8885abe..5fead483 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -16,11 +16,12 @@ from modules.models import local_rank
def get_max_prompt_length(tokens):
- max_length = 2048-tokens
+ max_length = 2048 - tokens
if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1]
return max_length
+
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt))
@@ -30,7 +31,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
- input_ids = input_ids[:,1:]
+ input_ids = input_ids[:, 1:]
if shared.args.cpu:
return input_ids
@@ -44,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
else:
return input_ids.cuda()
+
def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|>
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
@@ -53,14 +55,17 @@ def decode(output_ids):
reply = reply.replace(r'<|endoftext|>', '')
return reply
+
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
- #filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
+ # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
+
+
def fix_gpt4chan(s):
for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@@ -69,6 +74,8 @@ def fix_gpt4chan(s):
return s
# Fix the LaTeX equations in galactica
+
+
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
@@ -79,6 +86,7 @@ def fix_galactica(s):
s = re.sub(r"\n{3,}", "\n\n", s)
return s
+
def formatted_outputs(reply, model_name):
if not shared.is_chat():
if 'galactica' in model_name.lower():
@@ -92,20 +100,24 @@ def formatted_outputs(reply, model_name):
else:
return reply
+
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()
+
def set_manual_seed(seed):
if seed != -1:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
+
def stop_everything_event():
shared.stop_everything = True
+
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache()
set_manual_seed(generate_state['seed'])
@@ -128,7 +140,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params)
- output = original_question+reply
+ output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
@@ -139,7 +151,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
- output = original_question+reply
+ output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
@@ -240,7 +252,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
- for i in range(generate_state['max_new_tokens']//8+1):
+ for i in range(generate_state['max_new_tokens'] // 8 + 1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
@@ -271,6 +283,6 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
finally:
t1 = time.time()
original_tokens = len(original_input_ids[0])
- new_tokens = len(output)-original_tokens
+ new_tokens = len(output) - original_tokens
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return
diff --git a/modules/training.py b/modules/training.py
index 220428b1..9880cf00 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -19,9 +19,11 @@ CURRENT_STEPS = 0
MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1
+
def get_dataset(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
+
def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
@@ -44,16 +46,16 @@ def create_train_interface():
with gr.Tab(label="Formatted Dataset"):
with gr.Row():
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
- ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
+ ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
- ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
+ ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
- ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
+ ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
with gr.Tab(label="Raw Text File"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
- ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
+ ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
with gr.Row():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
@@ -67,10 +69,12 @@ def create_train_interface():
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
+
def do_interrupt():
global WANT_INTERRUPT
WANT_INTERRUPT = True
+
class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS, MAX_STEPS
@@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
+
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS
CURRENT_STEPS += 1
@@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
+
def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
return path
return f'{Path(base_path).absolute()}/{path}'
+
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
@@ -124,7 +131,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
elif not shared.args.load_in_8bit:
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
- time.sleep(2) # Give it a moment for the message to show in UI before continuing
+ time.sleep(2) # Give it a moment for the message to show in UI before continuing
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes."
@@ -148,7 +155,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text)
- del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
+ del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)):
@@ -197,18 +204,18 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
else:
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
-
+
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
print("Getting model ready...")
prepare_model_for_int8_training(shared.model)
-
+
print("Prepping for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
# TODO: Should target_modules be configurable?
- target_modules=[ "q_proj", "v_proj" ],
+ target_modules=["q_proj", "v_proj"],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
@@ -289,7 +296,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
timer_info = f"`{its:.2f}` it/s"
else:
timer_info = f"`{1.0/its:.2f}` s/it"
- total_time_estimate = (1.0/its) * (MAX_STEPS)
+ total_time_estimate = (1.0 / its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
print("Training complete, saving...")
@@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
print("Training complete!")
yield f"Done! LoRA saved to `{lora_name}`"
+
def split_chunks(arr, step):
for i in range(0, len(arr), step):
yield arr[i:i + step]
+
def cut_chunk_for_newline(chunk: str, max_length: int):
if '\n' not in chunk:
return chunk
@@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
chunk = chunk[:last_newline]
return chunk
+
def format_time(seconds: float):
if seconds < 120:
return f"`{seconds:.0f}` seconds"
diff --git a/modules/ui.py b/modules/ui.py
index 80bd7c1c..def1faaf 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read()
+
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
@@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
def get_block_name(self):
return "button"
+
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
diff --git a/server.py b/server.py
index 4ba5ba82..4fdeb80b 100644
--- a/server.py
+++ b/server.py
@@ -34,15 +34,18 @@ if settings_file is not None:
for item in new_settings:
shared.settings[item] = new_settings[item]
+
def get_available_models():
if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else:
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
def get_available_presets():
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
+
def get_available_prompts():
prompts = []
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
@@ -50,10 +53,12 @@ def get_available_prompts():
prompts += ['None']
return prompts
+
def get_available_characters():
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
+
def get_available_instruction_templates():
path = "characters/instruction-following"
paths = []
@@ -61,19 +66,24 @@ def get_available_instruction_templates():
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
+
def get_available_extensions():
- return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+ return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
def get_available_softprompts():
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
+
def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
def unload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
+
def load_model_wrapper(selected_model):
if selected_model != shared.model_name:
shared.model_name = selected_model
@@ -84,10 +94,12 @@ def load_model_wrapper(selected_model):
return selected_model
+
def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora)
return selected_lora
+
def load_preset_values(preset_menu, state, return_dict=False):
generate_params = {
'do_sample': True,
@@ -118,6 +130,7 @@ def load_preset_values(preset_menu, state, return_dict=False):
state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
zf.extract('meta.json')
@@ -130,12 +143,14 @@ def upload_soft_prompt(file):
return name
+
def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text)
return f"Saved to prompts/{fname}"
+
def load_prompt(fname):
if fname in ['None', '']:
return ''
@@ -146,12 +161,13 @@ def load_prompt(fname):
text = text[:-1]
return text
+
def create_prompt_menus():
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
- ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
+ ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
with gr.Column():
with gr.Column():
@@ -161,20 +177,22 @@ def create_prompt_menus():
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
+
def create_model_menus():
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
- ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
with gr.Column():
with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
- ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+ ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
+
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
@@ -185,7 +203,7 @@ def create_settings_menus(default_preset):
with gr.Column():
with gr.Row():
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
- ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
+ ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
with gr.Column():
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
@@ -196,12 +214,12 @@ def create_settings_menus(default_preset):
with gr.Row():
with gr.Column():
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
- shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
- shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
- shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
+ shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
+ shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
+ shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
with gr.Column():
- shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
- shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
+ shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
+ shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
@@ -209,7 +227,6 @@ def create_settings_menus(default_preset):
with gr.Box():
gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
-
with gr.Box():
gr.Markdown('Beam search (uses a lot of VRAM)')
with gr.Row():
@@ -219,11 +236,10 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
-
with gr.Accordion('Soft prompt', open=False):
with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
- ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
+ ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': get_available_softprompts()}, 'refresh-button')
gr.Markdown('Upload a soft prompt (.zip format):')
with gr.Row():
@@ -233,6 +249,7 @@ def create_settings_menus(default_preset):
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
+
def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args)
@@ -251,6 +268,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
shared.need_restart = True
+
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
@@ -284,7 +302,7 @@ else:
for i, model in enumerate(available_models):
print(f'{i+1}. {model}')
print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
- i = int(input())-1
+ i = int(input()) - 1
print()
shared.model_name = available_models[i]
shared.model, shared.tokenizer = load_model(shared.model_name)
@@ -297,15 +315,15 @@ if shared.lora_name != "None":
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
else:
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
-title ='Text generation web UI'
+title = 'Text generation web UI'
+
def create_interface():
-
gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions()
- with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+ with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.is_chat():
shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"):
@@ -342,7 +360,7 @@ def create_interface():
shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
- ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
+ ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
with gr.Row():
with gr.Tab('Chat history'):
@@ -399,11 +417,11 @@ def create_interface():
# Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
- shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
- shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+ shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
+ shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
- shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
- shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
+ shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+ shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
@@ -412,10 +430,10 @@ def create_interface():
# Clearing stuff and saving the history
for i in ['Generate', 'Regenerate', 'Replace last reply']:
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
- shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
- shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+ shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+ shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
- shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+ shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
@@ -430,7 +448,7 @@ def create_interface():
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
- shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
+ shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook:
@@ -526,7 +544,7 @@ def create_interface():
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
- shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'Reloading...\'; setTimeout(function(){location.reload()},2500); return []}')
+ shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'Reloading...\'; setTimeout(function(){location.reload()},2500); return []}')
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
@@ -562,6 +580,7 @@ def create_interface():
else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
+
create_interface()
while True: