mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Lora trainer improvements part 3 (#1098)
* add support for other model types dependent on future-peft-changes but with fallback to function now * use encoding=utf8 for training format * make shuffling optional and describe dropout a bit more * add eval_steps to control evaluation * make callbacks not depend on globals * make save steps controllable * placeholder of initial loading-existing-model support and var name cleanup * save/load parameters * last bit of cleanup * remove `gptq_bits` ref as main branch removed that setting * add higher_rank_limit option 2048 is basically unreachable due to VRAM, but i trained at 1536 with batch size = 1 on a 7B model. Note that it's in the do_train input just to save as a parameter * fix math on save_steps
This commit is contained in:
parent
ac19d5101f
commit
a3eec62b50
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@ -10,23 +11,47 @@ import torch
|
||||
import transformers
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
|
||||
prepare_model_for_int8_training)
|
||||
PeftModel, prepare_model_for_int8_training)
|
||||
|
||||
try: # This mapping is from a very recent commit, not yet released.
|
||||
from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
|
||||
except: # So good backup for the 3 safe model types if not yet available.
|
||||
standard_modules = ["q_proj", "v_proj"]
|
||||
model_to_lora_modules = { "llama": standard_modules, "opt": standard_modules, "gptj": standard_modules }
|
||||
|
||||
from modules import shared, ui
|
||||
|
||||
WANT_INTERRUPT = False
|
||||
CURRENT_STEPS = 0
|
||||
MAX_STEPS = 0
|
||||
CURRENT_GRADIENT_ACCUM = 1
|
||||
|
||||
|
||||
def get_dataset(path: str, ext: str):
|
||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha",
|
||||
"lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"]
|
||||
MODEL_CLASSES = { # Mapping of Python class names to peft IDs
|
||||
"LlamaForCausalLM": "llama",
|
||||
"OPTForCausalLM": "opt",
|
||||
"GPTJForCausalLM": "gptj"
|
||||
}
|
||||
|
||||
|
||||
def get_datasets(path: str, ext: str):
|
||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||
|
||||
|
||||
def get_available_loras():
|
||||
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
|
||||
with gr.Row():
|
||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=True, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (will use the original rank/alpha/dropout) (NOTE: Currently broken).')
|
||||
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
|
||||
|
||||
with gr.Row():
|
||||
copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=get_available_loras())
|
||||
ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
# TODO: Implement multi-device support.
|
||||
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
||||
@ -37,25 +62,29 @@ def create_train_interface():
|
||||
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
||||
|
||||
# TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
|
||||
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.')
|
||||
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
|
||||
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
||||
# TODO: Better explain what this does, in terms of real world effect especially.
|
||||
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers.')
|
||||
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
||||
with gr.Row():
|
||||
do_shuffle = gr.Checkbox(label='Shuffle Dataset', value=True, info='If checked, the dataset will be randomly shuffled. This can help reduce overfitting.')
|
||||
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||
|
||||
with gr.Tab(label="Formatted Dataset"):
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
|
||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||
|
||||
with gr.Tab(label="Raw Text File"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
||||
with gr.Row():
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||
@ -65,32 +94,30 @@ def create_train_interface():
|
||||
stop_button = gr.Button("Interrupt")
|
||||
|
||||
output = gr.Markdown(value="Ready")
|
||||
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
|
||||
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
|
||||
cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
|
||||
start_button.click(do_train, all_params, [output])
|
||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||
|
||||
def do_copy_params(lora_name: str):
|
||||
with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile:
|
||||
params: dict[str, str] = json.load(formatFile)
|
||||
return [params[x] for x in PARAMETERS]
|
||||
|
||||
copy_from.change(do_copy_params, [copy_from], all_params)
|
||||
|
||||
def change_rank_limit(use_higher_ranks: bool):
|
||||
mult = 2 if use_higher_ranks else 1
|
||||
return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
|
||||
|
||||
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
|
||||
|
||||
|
||||
def do_interrupt():
|
||||
global WANT_INTERRUPT
|
||||
WANT_INTERRUPT = True
|
||||
|
||||
|
||||
class Callbacks(transformers.TrainerCallback):
|
||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS, MAX_STEPS
|
||||
CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM
|
||||
MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM
|
||||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS
|
||||
CURRENT_STEPS += 1
|
||||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
|
||||
def clean_path(base_path: str, path: str):
|
||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||
@ -102,26 +129,27 @@ def clean_path(base_path: str, path: str):
|
||||
return f'{Path(base_path).absolute()}/{path}'
|
||||
|
||||
|
||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
|
||||
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||
cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool):
|
||||
global WANT_INTERRUPT
|
||||
WANT_INTERRUPT = False
|
||||
CURRENT_STEPS = 0
|
||||
MAX_STEPS = 0
|
||||
|
||||
# == Input validation / processing ==
|
||||
yield "Prepping..."
|
||||
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
||||
lora_file_path = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
||||
actual_lr = float(learning_rate)
|
||||
|
||||
model_type = type(shared.model).__name__
|
||||
if model_type != "LlamaForCausalLM":
|
||||
if model_type in MODEL_CLASSES:
|
||||
model_id = MODEL_CLASSES[model_type]
|
||||
else:
|
||||
model_id == "llama"
|
||||
if model_type == "PeftModelForCausalLM":
|
||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||
else:
|
||||
yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
|
||||
yield "LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})")
|
||||
time.sleep(5)
|
||||
|
||||
if shared.args.wbits > 0:
|
||||
@ -168,7 +196,6 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
|
||||
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||
del text_chunks
|
||||
train_data = train_data.shuffle()
|
||||
eval_data = None
|
||||
|
||||
else:
|
||||
@ -180,7 +207,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
yield "**Missing format choice input, cannot continue.**"
|
||||
return
|
||||
|
||||
with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
|
||||
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8') as formatFile:
|
||||
format_data: dict[str, str] = json.load(formatFile)
|
||||
|
||||
def generate_prompt(data_point: dict[str, str]):
|
||||
@ -198,13 +225,18 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
|
||||
print("Loading JSON datasets...")
|
||||
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
||||
train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)
|
||||
train_data = data['train'].map(generate_and_tokenize_prompt)
|
||||
|
||||
if eval_dataset == 'None':
|
||||
eval_data = None
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
|
||||
eval_data = eval_data['train'].map(generate_and_tokenize_prompt)
|
||||
if do_shuffle:
|
||||
eval_data = eval_data.shuffle()
|
||||
|
||||
if do_shuffle:
|
||||
train_data = train_data.shuffle()
|
||||
|
||||
# == Start prepping the model itself ==
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
@ -215,19 +247,44 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
# TODO: Should target_modules be configurable?
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
target_modules=model_to_lora_modules[model_id],
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
try:
|
||||
lora_model = get_peft_model(shared.model, config)
|
||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
||||
print("Loading existing LoRA file...")
|
||||
lora_model = PeftModel.from_pretrained(shared.model, lora_file_path)
|
||||
else:
|
||||
print("Creating new LoRA model...")
|
||||
lora_model = get_peft_model(shared.model, config)
|
||||
except:
|
||||
yield traceback.format_exc()
|
||||
return
|
||||
|
||||
class Tracked():
|
||||
def __init__(self):
|
||||
self.current_steps = 0
|
||||
self.max_steps = 0
|
||||
|
||||
tracked = Tracked()
|
||||
|
||||
class Callbacks(transformers.TrainerCallback):
|
||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
tracked.current_steps = state.global_step * CURRENT_GRADIENT_ACCUM
|
||||
tracked.max_steps = state.max_steps * CURRENT_GRADIENT_ACCUM
|
||||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
tracked.current_steps += 1
|
||||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=lora_model,
|
||||
train_dataset=train_data,
|
||||
@ -235,18 +292,16 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
args=transformers.TrainingArguments(
|
||||
per_device_train_batch_size=micro_batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
# TODO: Should more of these be configurable? Probably.
|
||||
warmup_steps=100,
|
||||
num_train_epochs=epochs,
|
||||
learning_rate=actual_lr,
|
||||
fp16=False if shared.args.cpu else True,
|
||||
logging_steps=20,
|
||||
logging_steps=5,
|
||||
evaluation_strategy="steps" if eval_data is not None else "no",
|
||||
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
|
||||
save_strategy="steps",
|
||||
eval_steps=200 if eval_data is not None else None,
|
||||
save_steps=200,
|
||||
output_dir=lora_name,
|
||||
save_total_limit=3,
|
||||
save_steps=math.ceil(save_steps / gradient_accumulation_steps),
|
||||
output_dir=lora_file_path,
|
||||
load_best_model_at_end=True if eval_data is not None else False,
|
||||
# TODO: Enable multi-device support
|
||||
ddp_find_unused_parameters=None,
|
||||
@ -265,8 +320,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
lora_model = torch.compile(lora_model)
|
||||
|
||||
# == Save parameters for reuse ==
|
||||
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
|
||||
vars = locals()
|
||||
json.dump({x: vars[x] for x in PARAMETERS}, file)
|
||||
|
||||
# == Main run and monitor loop ==
|
||||
# TODO: save/load checkpoints to resume from?
|
||||
print("Starting training...")
|
||||
yield "Starting..."
|
||||
if WANT_INTERRUPT:
|
||||
@ -286,30 +345,30 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
if WANT_INTERRUPT:
|
||||
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
||||
|
||||
elif CURRENT_STEPS != last_step:
|
||||
last_step = CURRENT_STEPS
|
||||
elif tracked.current_steps != last_step:
|
||||
last_step = tracked.current_steps
|
||||
time_elapsed = time.perf_counter() - start_time
|
||||
if time_elapsed <= 0:
|
||||
timer_info = ""
|
||||
total_time_estimate = 999
|
||||
else:
|
||||
its = CURRENT_STEPS / time_elapsed
|
||||
its = tracked.current_steps / time_elapsed
|
||||
if its > 1:
|
||||
timer_info = f"`{its:.2f}` it/s"
|
||||
else:
|
||||
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||
total_time_estimate = (1.0 / its) * (MAX_STEPS)
|
||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||
total_time_estimate = (1.0 / its) * (tracked.max_steps)
|
||||
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||
|
||||
print("Training complete, saving...")
|
||||
lora_model.save_pretrained(lora_name)
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
|
||||
if WANT_INTERRUPT:
|
||||
print("Training interrupted.")
|
||||
yield f"Interrupted. Incomplete LoRA saved to `{lora_name}`"
|
||||
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
|
||||
else:
|
||||
print("Training complete!")
|
||||
yield f"Done! LoRA saved to `{lora_name}`"
|
||||
yield f"Done! LoRA saved to `{lora_file_path}`"
|
||||
|
||||
|
||||
def split_chunks(arr, step):
|
||||
|
Loading…
Reference in New Issue
Block a user