resume_from_checkpoint

Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
This commit is contained in:
Eric Wang 2023-03-26 17:17:54 -07:00
parent b948f892ba
commit da6b427a08

View File

@ -18,6 +18,7 @@ from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
@ -43,7 +44,8 @@ def train(
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
group_by_length: bool = False, # faster, but produces an odd training loss curve
group_by_length: bool = False, # faster, but produces an odd training loss curve,
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
):
print(
f"Training Alpaca-LoRA model with params:\n"
@ -62,6 +64,7 @@ def train(
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"group_by_length: {group_by_length}\n"
f"resume_from_checkpoint: {resume_from_checkpoint}\n"
)
assert (
base_model
@ -137,6 +140,24 @@ def train(
data = load_dataset("json", data_files=data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = False # So the trainer won't try loading its state
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
model = set_peft_model_state_dict(model, adapters_weights)
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
@ -183,7 +204,7 @@ def train(
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
trainer.train()
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)