mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Several Training Enhancements (#2868)
This commit is contained in:
parent
95212edf1f
commit
21c189112c
@ -49,9 +49,10 @@ except:
|
||||
}
|
||||
|
||||
train_log = {}
|
||||
train_template = {}
|
||||
|
||||
WANT_INTERRUPT = False
|
||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after"]
|
||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss"]
|
||||
|
||||
|
||||
def create_train_interface():
|
||||
@ -109,6 +110,7 @@ def create_train_interface():
|
||||
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
|
||||
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
|
||||
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
|
||||
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
||||
|
||||
with gr.Row():
|
||||
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||
@ -142,7 +144,7 @@ def create_train_interface():
|
||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
||||
|
||||
# Training events
|
||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after]
|
||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss]
|
||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||
start_button.click(do_train, all_params, output)
|
||||
stop_button.click(do_interrupt, None, None, queue=False)
|
||||
@ -206,7 +208,7 @@ def clean_path(base_path: str, path: str):
|
||||
return f'{Path(base_path).absolute()}/{path}'
|
||||
|
||||
|
||||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str):
|
||||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
|
||||
|
||||
if shared.args.monkey_patch:
|
||||
from monkeypatch.peft_tuners_lora_monkey_patch import (
|
||||
@ -296,9 +298,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
|
||||
}
|
||||
|
||||
train_template.clear()
|
||||
|
||||
# == Prep the dataset, format, etc ==
|
||||
if raw_text_file not in ['None', '']:
|
||||
logger.info("Loading raw text file dataset...")
|
||||
|
||||
train_template["template_type"] = "raw_text"
|
||||
|
||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||
raw_text = file.read().replace('\r', '')
|
||||
|
||||
@ -330,7 +337,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||
del text_chunks
|
||||
eval_data = None
|
||||
|
||||
else:
|
||||
if dataset in ['None', '']:
|
||||
yield "**Missing dataset choice input, cannot continue.**"
|
||||
@ -340,9 +346,16 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
yield "**Missing format choice input, cannot continue.**"
|
||||
return
|
||||
|
||||
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8') as formatFile:
|
||||
train_template["template_type"] = "dataset"
|
||||
|
||||
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
||||
format_data: dict[str, str] = json.load(formatFile)
|
||||
|
||||
# == store training prompt ==
|
||||
for _, value in format_data.items():
|
||||
prompt_key = f"template_{len(train_template)}"
|
||||
train_template[prompt_key] = value
|
||||
|
||||
def generate_prompt(data_point: dict[str, str]):
|
||||
for options, data in format_data.items():
|
||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
|
||||
@ -369,7 +382,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
# == Start prepping the model itself ==
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
logger.info("Getting model ready...")
|
||||
prepare_model_for_kbit_training(shared.model)
|
||||
prepare_model_for_int8_training(shared.model)
|
||||
|
||||
logger.info("Prepping for training...")
|
||||
config = LoraConfig(
|
||||
@ -421,7 +434,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
# Save log
|
||||
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
|
||||
json.dump(train_log, file, indent=2)
|
||||
|
||||
# == Save training prompt ==
|
||||
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file:
|
||||
json.dump(train_template, file, indent=2)
|
||||
|
||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
tracked.current_steps += 1
|
||||
@ -431,6 +446,17 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
|
||||
def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
|
||||
train_log.update(logs)
|
||||
train_log.update({"current_steps": tracked.current_steps})
|
||||
if WANT_INTERRUPT:
|
||||
print("\033[1;31;1mInterrupted by user\033[0;37;0m")
|
||||
|
||||
print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
|
||||
if 'loss' in logs:
|
||||
loss = float(logs['loss'])
|
||||
if loss <= stop_at_loss:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=lora_model,
|
||||
@ -444,7 +470,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
learning_rate=actual_lr,
|
||||
fp16=False if shared.args.cpu else True,
|
||||
optim=optimizer,
|
||||
logging_steps=5,
|
||||
logging_steps=2 if stop_at_loss > 0 else 5,
|
||||
evaluation_strategy="steps" if eval_data is not None else "no",
|
||||
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
|
||||
save_strategy="steps" if eval_data is not None else "no",
|
||||
@ -469,9 +495,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
vars = locals()
|
||||
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
|
||||
|
||||
# == Save training prompt ==
|
||||
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
|
||||
json.dump(train_template, file, indent=2)
|
||||
|
||||
# == Main run and monitor loop ==
|
||||
logger.info("Starting training...")
|
||||
yield "Starting..."
|
||||
|
||||
train_log.update({"base_model_name": shared.model_name})
|
||||
train_log.update({"base_model_class": shared.model.__class__.__name__})
|
||||
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
|
||||
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
|
||||
|
||||
if stop_at_loss > 0:
|
||||
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
|
||||
|
||||
if WANT_INTERRUPT:
|
||||
yield "Interrupted before start."
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user