mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
4-Bit LoRA training + several new training options and fixes
This commit is contained in:
parent
702fe92d42
commit
ee30625cd1
@ -10,8 +10,7 @@ import gradio as gr
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset, load_dataset
|
from datasets import Dataset, load_dataset
|
||||||
from peft import (LoraConfig, PeftModel, get_peft_model,
|
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, prepare_model_for_int8_training
|
||||||
get_peft_model_state_dict, prepare_model_for_int8_training)
|
|
||||||
|
|
||||||
from modules import shared, ui
|
from modules import shared, ui
|
||||||
|
|
||||||
@ -27,7 +26,7 @@ except:
|
|||||||
|
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
|
|
||||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"]
|
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit", "warmup_steps", "optimizer"]
|
||||||
|
|
||||||
# Mapping of Python class names to peft IDs
|
# Mapping of Python class names to peft IDs
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
@ -49,7 +48,7 @@ def create_train_interface():
|
|||||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||||
always_override = gr.Checkbox(label='Override Existing Files', value=True, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (will use the original rank/alpha/dropout) (NOTE: Currently broken).')
|
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
|
||||||
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
|
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -64,19 +63,15 @@ def create_train_interface():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
|
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
|
||||||
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
||||||
|
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.')
|
||||||
|
|
||||||
# TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
|
# TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
|
||||||
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
|
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
|
||||||
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
||||||
|
|
||||||
# TODO: Better explain what this does, in terms of real world effect especially.
|
|
||||||
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
|
||||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
||||||
with gr.Row():
|
|
||||||
do_shuffle = gr.Checkbox(label='Shuffle Dataset', value=True, info='If checked, the dataset will be randomly shuffled. This can help reduce overfitting.')
|
|
||||||
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
|
||||||
|
|
||||||
with gr.Tab(label="Formatted Dataset"):
|
with gr.Tab(label='Formatted Dataset'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||||
@ -87,7 +82,7 @@ def create_train_interface():
|
|||||||
|
|
||||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||||
|
|
||||||
with gr.Tab(label="Raw Text File"):
|
with gr.Tab(label='Raw Text File'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
||||||
@ -96,24 +91,23 @@ def create_train_interface():
|
|||||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||||
|
|
||||||
|
with gr.Accordion(label='Advanced Options', open=False):
|
||||||
|
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
||||||
|
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
|
||||||
|
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
do_shuffle = gr.Checkbox(label='Shuffle Dataset', value=True, info='If checked, the dataset will be randomly shuffled. This can help reduce overfitting.')
|
||||||
|
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_button = gr.Button("Start LoRA Training")
|
start_button = gr.Button("Start LoRA Training")
|
||||||
stop_button = gr.Button("Interrupt")
|
stop_button = gr.Button("Interrupt")
|
||||||
|
|
||||||
output = gr.Markdown(value="Ready")
|
output = gr.Markdown(value="Ready")
|
||||||
|
|
||||||
def do_copy_params(lora_name: str):
|
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
|
||||||
with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile:
|
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||||
params: dict[str, str] = json.load(formatFile)
|
|
||||||
|
|
||||||
return [params[x] for x in PARAMETERS]
|
|
||||||
|
|
||||||
def change_rank_limit(use_higher_ranks: bool):
|
|
||||||
mult = 2 if use_higher_ranks else 1
|
|
||||||
return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
|
|
||||||
|
|
||||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
|
|
||||||
copy_from.change(do_copy_params, copy_from, all_params)
|
|
||||||
start_button.click(do_train, all_params, output)
|
start_button.click(do_train, all_params, output)
|
||||||
stop_button.click(do_interrupt, None, None, queue=False)
|
stop_button.click(do_interrupt, None, None, queue=False)
|
||||||
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
|
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
|
||||||
@ -124,6 +118,29 @@ def do_interrupt():
|
|||||||
WANT_INTERRUPT = True
|
WANT_INTERRUPT = True
|
||||||
|
|
||||||
|
|
||||||
|
def do_copy_params(lora_name: str, *args):
|
||||||
|
f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
|
||||||
|
if Path(f_name).is_file():
|
||||||
|
with open(f_name, 'r', encoding='utf-8') as format_file:
|
||||||
|
params: dict[str, str] = json.load(format_file)
|
||||||
|
else:
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
result = list()
|
||||||
|
for i in range(0, len(PARAMETERS)):
|
||||||
|
key = PARAMETERS[i]
|
||||||
|
if key in params:
|
||||||
|
result.append(params[key])
|
||||||
|
else:
|
||||||
|
result.append(args[i])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def change_rank_limit(use_higher_ranks: bool):
|
||||||
|
mult = 2 if use_higher_ranks else 1
|
||||||
|
return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
|
||||||
|
|
||||||
|
|
||||||
def clean_path(base_path: str, path: str):
|
def clean_path(base_path: str, path: str):
|
||||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||||
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
||||||
@ -135,14 +152,23 @@ def clean_path(base_path: str, path: str):
|
|||||||
return f'{Path(base_path).absolute()}/{path}'
|
return f'{Path(base_path).absolute()}/{path}'
|
||||||
|
|
||||||
|
|
||||||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool, warmup_steps: int, optimizer: str):
|
||||||
cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool):
|
|
||||||
|
if shared.args.monkey_patch:
|
||||||
|
from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model
|
||||||
|
replace_peft_model_with_gptq_lora_model()
|
||||||
|
|
||||||
global WANT_INTERRUPT
|
global WANT_INTERRUPT
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
|
|
||||||
# == Input validation / processing ==
|
# == Input validation / processing ==
|
||||||
yield "Prepping..."
|
yield "Prepping..."
|
||||||
lora_file_path = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
lora_file_path = clean_path(None, lora_name)
|
||||||
|
if lora_file_path.strip() == '':
|
||||||
|
yield "Missing or invalid LoRA file name input."
|
||||||
|
return
|
||||||
|
|
||||||
|
lora_file_path = f"{shared.args.lora_dir}/{lora_file_path}"
|
||||||
actual_lr = float(learning_rate)
|
actual_lr = float(learning_rate)
|
||||||
model_type = type(shared.model).__name__
|
model_type = type(shared.model).__name__
|
||||||
|
|
||||||
@ -158,11 +184,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})")
|
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})")
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
if shared.args.wbits > 0:
|
if shared.args.wbits > 0 and not shared.args.monkey_patch:
|
||||||
yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
|
yield "LoRA training in 4-bit requires loading with `--monkey-patch`"
|
||||||
return
|
return
|
||||||
|
|
||||||
elif not shared.args.load_in_8bit:
|
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
|
||||||
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||||
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||||
@ -172,7 +198,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
return
|
return
|
||||||
|
|
||||||
gradient_accumulation_steps = batch_size // micro_batch_size
|
gradient_accumulation_steps = batch_size // micro_batch_size
|
||||||
CURRENT_GRADIENT_ACCUM = gradient_accumulation_steps
|
|
||||||
shared.tokenizer.pad_token = 0
|
shared.tokenizer.pad_token = 0
|
||||||
shared.tokenizer.padding_side = "left"
|
shared.tokenizer.padding_side = "left"
|
||||||
|
|
||||||
@ -260,30 +285,41 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
print("Creating LoRA model...")
|
||||||
|
lora_model = get_peft_model(shared.model, config)
|
||||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
||||||
print("Loading existing LoRA file...")
|
print("Loading existing LoRA data...")
|
||||||
lora_model = PeftModel.from_pretrained(shared.model, lora_file_path)
|
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
|
||||||
else:
|
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||||
print("Creating new LoRA model...")
|
|
||||||
lora_model = get_peft_model(shared.model, config)
|
|
||||||
except:
|
except:
|
||||||
yield traceback.format_exc()
|
yield traceback.format_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if shared.args.monkey_patch:
|
||||||
|
for n, m in lora_model.named_modules():
|
||||||
|
if '4bit' in str(type(m)):
|
||||||
|
if m.is_v1_model:
|
||||||
|
m.zeros = m.zeros.half()
|
||||||
|
m.scales = m.scales.half()
|
||||||
|
|
||||||
class Tracked():
|
class Tracked():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.current_steps = 0
|
self.current_steps = 0
|
||||||
self.max_steps = 0
|
self.max_steps = 0
|
||||||
|
self.did_save = False
|
||||||
|
|
||||||
tracked = Tracked()
|
tracked = Tracked()
|
||||||
|
actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps)
|
||||||
|
|
||||||
class Callbacks(transformers.TrainerCallback):
|
class Callbacks(transformers.TrainerCallback):
|
||||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
tracked.current_steps = state.global_step * CURRENT_GRADIENT_ACCUM
|
tracked.current_steps = state.global_step * gradient_accumulation_steps
|
||||||
tracked.max_steps = state.max_steps * CURRENT_GRADIENT_ACCUM
|
tracked.max_steps = state.max_steps * gradient_accumulation_steps
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
control.should_epoch_stop = True
|
control.should_epoch_stop = True
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
|
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
|
||||||
|
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
|
||||||
|
|
||||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
tracked.current_steps += 1
|
tracked.current_steps += 1
|
||||||
@ -298,16 +334,17 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
args=transformers.TrainingArguments(
|
args=transformers.TrainingArguments(
|
||||||
per_device_train_batch_size=micro_batch_size,
|
per_device_train_batch_size=micro_batch_size,
|
||||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
warmup_steps=100,
|
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
|
||||||
num_train_epochs=epochs,
|
num_train_epochs=epochs,
|
||||||
learning_rate=actual_lr,
|
learning_rate=actual_lr,
|
||||||
fp16=False if shared.args.cpu else True,
|
fp16=False if shared.args.cpu else True,
|
||||||
|
optim=optimizer,
|
||||||
logging_steps=5,
|
logging_steps=5,
|
||||||
evaluation_strategy="steps" if eval_data is not None else "no",
|
evaluation_strategy="steps" if eval_data is not None else "no",
|
||||||
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
|
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
|
||||||
save_strategy="steps",
|
save_strategy="no",
|
||||||
save_steps=math.ceil(save_steps / gradient_accumulation_steps),
|
|
||||||
output_dir=lora_file_path,
|
output_dir=lora_file_path,
|
||||||
|
lr_scheduler_type=lr_scheduler_type,
|
||||||
load_best_model_at_end=True if eval_data is not None else False,
|
load_best_model_at_end=True if eval_data is not None else False,
|
||||||
# TODO: Enable multi-device support
|
# TODO: Enable multi-device support
|
||||||
ddp_find_unused_parameters=None,
|
ddp_find_unused_parameters=None,
|
||||||
@ -318,10 +355,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
)
|
)
|
||||||
|
|
||||||
lora_model.config.use_cache = False
|
lora_model.config.use_cache = False
|
||||||
old_state_dict = lora_model.state_dict
|
|
||||||
lora_model.state_dict = (
|
|
||||||
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
|
||||||
).__get__(lora_model, type(lora_model))
|
|
||||||
|
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
lora_model = torch.compile(lora_model)
|
lora_model = torch.compile(lora_model)
|
||||||
@ -340,6 +373,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
def threaded_run():
|
def threaded_run():
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||||
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
print("LoRA training run is completed and saved.")
|
||||||
|
tracked.did_save = True
|
||||||
|
|
||||||
thread = threading.Thread(target=threaded_run)
|
thread = threading.Thread(target=threaded_run)
|
||||||
thread.start()
|
thread.start()
|
||||||
@ -368,8 +405,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||||
|
|
||||||
print("Training complete, saving...")
|
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||||
lora_model.save_pretrained(lora_file_path)
|
if not tracked.did_save:
|
||||||
|
print("Training complete, saving...")
|
||||||
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
print("Training interrupted.")
|
print("Training interrupted.")
|
||||||
|
Loading…
Reference in New Issue
Block a user