diff --git a/finetune.py b/finetune.py
index 02f6b9d..2f5fe21 100644
--- a/finetune.py
+++ b/finetune.py
@@ -18,6 +18,7 @@ from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
+    set_peft_model_state_dict,
 )
@@ -43,7 +44,8 @@ def train(
     ],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
-    group_by_length: bool = False,  # faster, but produces an odd training loss curve
+    group_by_length: bool = False,  # faster, but produces an odd training loss curve,
+    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
         f"Training Alpaca-LoRA model with params:\n"
@@ -62,6 +64,7 @@ def train(
         f"lora_target_modules: {lora_target_modules}\n"
         f"train_on_inputs: {train_on_inputs}\n"
         f"group_by_length: {group_by_length}\n"
+        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
     )
     assert (
         base_model
@@ -137,6 +140,24 @@ def train(
 
     data = load_dataset("json", data_files=data_path)
 
+    if resume_from_checkpoint:
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+            resume_from_checkpoint = False  # So the trainer won't try loading its state
+        # The two files above have a different name depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            model = set_peft_model_state_dict(model, adapters_weights)
+
+    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
+
     if val_set_size > 0:
         train_val = data["train"].train_test_split(
             test_size=val_set_size, shuffle=True, seed=42
@@ -183,7 +204,7 @@ def train(
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
-    trainer.train()
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     model.save_pretrained(output_dir)
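
For context, a minimal usage sketch of the new flag follows; the base model name, dataset path, and checkpoint directory are illustrative assumptions, not part of the patch. Per the logic above, resume_from_checkpoint can point either at a full Trainer checkpoint directory (containing pytorch_model.bin, so trainer state is also resumed) or at a directory holding only adapter_model.bin (only the LoRA weights are restored).

# Hedged usage sketch; all paths and names below are placeholders.
from finetune import train

train(
    base_model="decapoda-research/llama-7b-hf",  # assumed base model
    data_path="alpaca_data.json",                # assumed instruction dataset
    output_dir="./lora-alpaca",
    # Either a full Trainer checkpoint dir (pytorch_model.bin) to resume
    # optimizer/trainer state, or a dir with only adapter_model.bin, in
    # which case just the LoRA weights are loaded before training starts.
    resume_from_checkpoint="./lora-alpaca/checkpoint-200",
)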