Mirror of https://github.com/tloen/alpaca-lora.git, synced 2024-10-01 01:05:56 -04:00
resume_from_checkpoint
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
parent b948f892ba
commit da6b427a08
finetune.py: 25 changes (23 additions, 2 deletions)
@@ -18,6 +18,7 @@ from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
+    set_peft_model_state_dict,
 )
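The newly imported set_peft_model_state_dict is what the checkpoint-loading block further down uses to push saved adapter weights back into the PEFT-wrapped model. A minimal sketch of that call in isolation, assuming `model` is an already-built PeftModel with a LoraConfig matching the saved weights (the filename is a placeholder):

import torch
from peft import set_peft_model_state_dict

# Sketch only: `model` is assumed to be a PeftModel created with the same
# LoraConfig that produced the saved weights; the path is illustrative.
adapters_weights = torch.load("adapter_model.bin", map_location="cpu")
set_peft_model_state_dict(model, adapters_weights)  # loads the adapter weights in place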
@@ -43,7 +44,8 @@ def train(
     ],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
-    group_by_length: bool = False,  # faster, but produces an odd training loss curve
+    group_by_length: bool = False,  # faster, but produces an odd training loss curve,
+    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
         f"Training Alpaca-LoRA model with params:\n"
@@ -62,6 +64,7 @@ def train(
         f"lora_target_modules: {lora_target_modules}\n"
         f"train_on_inputs: {train_on_inputs}\n"
         f"group_by_length: {group_by_length}\n"
+        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
     )
     assert (
         base_model
@@ -137,6 +140,24 @@ def train(

     data = load_dataset("json", data_files=data_path)

+    if resume_from_checkpoint:
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+            resume_from_checkpoint = False  # So the trainer won't try loading its state
+        # The two files above have a different name depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            model = set_peft_model_state_dict(model, adapters_weights)
+
+    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
+
     if val_set_size > 0:
         train_val = data["train"].train_test_split(
             test_size=val_set_size, shuffle=True, seed=42
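Note on the block above: pytorch_model.bin is the filename transformers.Trainer uses for model weights inside a step checkpoint directory, while adapter_model.bin is what PeftModel.save_pretrained writes for a finished adapter; as the in-code comment says, for a LoRA model they hold the same tensors. When only the adapter file is found, resume_from_checkpoint is cleared so the Trainer will not also try to restore optimizer and scheduler state. The same resolution logic, expressed as a hypothetical standalone helper:

import os

def resolve_checkpoint(path: str):
    # Hypothetical helper mirroring the commit's logic: prefer a full
    # Trainer checkpoint, fall back to a bare LoRA adapter file.
    full = os.path.join(path, "pytorch_model.bin")
    adapter = os.path.join(path, "adapter_model.bin")
    if os.path.exists(full):
        return full, True      # weights, plus resumable trainer state
    if os.path.exists(adapter):
        return adapter, False  # weights only; trainer state starts fresh
    return None, False         # nothing to restore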
@@ -183,7 +204,7 @@ def train(
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)

-    trainer.train()
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

     model.save_pretrained(output_dir)

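With this change a run can be resumed either from an intermediate Trainer checkpoint or from a previously saved adapter. An illustrative invocation, assuming train()'s keyword arguments are exposed as command-line flags (the repo drives finetune.py via fire); the model, dataset, and checkpoint paths below are placeholders:

python finetune.py \
    --base_model 'decapoda-research/llama-7b-hf' \
    --data_path 'yahma/alpaca-cleaned' \
    --output_dir './lora-alpaca' \
    --resume_from_checkpoint './lora-alpaca/checkpoint-200'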