More robust and error prone training (#3058)

This commit is contained in:
FartyPants 2023-07-12 14:29:43 -04:00 committed by GitHub
parent 30f37530d5
commit 9b55d3a9f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 6 deletions

View File

@ -478,11 +478,16 @@ def load_character(character, name1, name2, instruct=False):
if character not in ['None', '', None]:
folder = 'characters' if not instruct else 'characters/instruction-following'
picture = generate_pfp_cache(character)
filepath = None
for extension in ["yml", "yaml", "json"]:
filepath = Path(f'{folder}/{character}.{extension}')
if filepath.exists():
break
if filepath is None:
logger.error(f"Could not find character file for {character} in {folder} folder. Please check your spelling.")
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
file_contents = open(filepath, 'r', encoding='utf-8').read()
data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)

View File

@ -339,6 +339,7 @@ def clear_torch_cache():
def unload_model():
shared.model = shared.tokenizer = None
shared.lora_names = []
shared.model_dirty_from_training = False
clear_torch_cache()

View File

@ -12,6 +12,7 @@ tokenizer = None
is_seq2seq = False
model_name = "None"
lora_names = []
model_dirty_from_training = False
# Chat variables
stop_everything = False

View File

@ -17,6 +17,8 @@ from pathlib import Path
import gradio as gr
import torch
import transformers
from modules.models import load_model, unload_model
from datasets import Dataset, load_dataset
from peft import (
LoraConfig,
@ -60,7 +62,7 @@ train_log = {}
train_template = {}
WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "report_to"]
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
def create_train_interface():
@ -108,6 +110,7 @@ def create_train_interface():
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
with gr.Row():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
@ -119,6 +122,7 @@ def create_train_interface():
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
with gr.Row():
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
@ -154,7 +158,8 @@ def create_train_interface():
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
# Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, report_to]
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output)
@ -264,7 +269,7 @@ def calc_trainable_parameters(model):
return trainable_params, all_param
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, report_to: str):
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
if shared.args.monkey_patch:
from monkeypatch.peft_tuners_lora_monkey_patch import (
@ -322,14 +327,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def encode(text, add_bos_token):
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
# Check if the first two tokens are BOS
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
result = result[1:]
if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
result = result[1:]
return result
def tokenize(prompt):
def tokenize(prompt, append_eos_token=False):
if train_only_after == '' or train_only_after not in prompt:
input_ids = encode(prompt, True)
if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
input_ids.append(shared.tokenizer.eos_token_id)
input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
labels = [1] * len(input_ids)
@ -338,6 +351,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
before_tokens = encode(prompt[:ind], True)
after_tokens = encode(prompt[ind:], False)
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
after_tokens.append(shared.tokenizer.eos_token_id)
full_length = len(after_tokens) + len(before_tokens)
if full_length > cutoff_len:
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
@ -377,12 +393,18 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
raw_text = file.read().replace('\r', '')
cut_string = hard_cut_string.replace('\\n', '\n')
eos_added = 0
out_tokens = []
for text_part in raw_text.split(cut_string):
if text_part.strip() == '':
if len(text_part.strip()) <= min_chars:
continue
tokens = shared.tokenizer.encode(text_part)
if add_eos_token:
tokens.append(shared.tokenizer.eos_token_id)
eos_added += 1
step = cutoff_len - overlap_len
if step <= 0:
yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
@ -390,6 +412,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
out_tokens.extend(split_chunks(tokens, cutoff_len, step))
if eos_added > 0:
print(f"EOS added to {eos_added} text blocks")
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
del out_tokens
@ -429,7 +454,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize(prompt)
return tokenize(prompt, add_eos_token)
logger.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
@ -441,11 +466,33 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
# == We MUST reload model if it went through any previous training, even failed one ==
if shared.model_dirty_from_training:
selected_model = shared.model_name
if selected_model:
print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m")
try:
yield f"Reloading {selected_model}..."
unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name, None)
if shared.model is not None:
print("Model reloaded OK, continue with training.")
else:
return f"Failed to load {selected_model}."
except:
exc = traceback.format_exc()
logger.error('Failed to reload the model.')
print(exc)
return exc
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...")
prepare_model_for_int8_training(shared.model)
# base model is now frozen and should not be reused for any other LoRA training than this one
shared.model_dirty_from_training = True
logger.info("Prepping for training...")
config = LoraConfig(
r=lora_rank,
@ -575,6 +622,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
print(f"Training '{model_id}' model using ({projections_string}) projections")
if lora_all_param > 0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
@ -582,6 +633,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
train_log.update({"base_model_class": shared.model.__class__.__name__})
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
train_log.update({"projections": projections_string})
if stop_at_loss > 0:
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")