Training update - backup the existing adapter before training on top of it (#2902)

FartyPants 2023-06-27 17:24:04 -04:00 committed by GitHub
parent 40bbd53640
commit ab1998146b

@@ -10,6 +10,10 @@ from pathlib import Path
import gradio as gr
import torch
import transformers
import shutil
from datetime import datetime
from datasets import Dataset, load_dataset
from peft import (
    LoraConfig,
@@ -208,6 +212,35 @@ def clean_path(base_path: str, path: str):
    return f'{Path(base_path).absolute()}/{path}'
def backup_adapter(input_folder):
    # Get the creation date of the file adapter_model.bin
    try:
        adapter_file = Path(f"{input_folder}/adapter_model.bin")
        if adapter_file.is_file():
            logger.info("Backing up existing LoRA adapter...")
            creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
            creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")

            # Create the new subfolder
            subfolder_path = Path(f"{input_folder}/{creation_date_str}")
            subfolder_path.mkdir(parents=True, exist_ok=True)

            # Check if the file already exists in the subfolder
            backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
            if backup_adapter_file.is_file():
                print(" - Backup already exists. Skipping backup process.")
                return

            # Copy existing files to the new subfolder
            existing_files = Path(input_folder).iterdir()
            for file in existing_files:
                if file.is_file():
                    shutil.copy2(file, subfolder_path)

    except Exception as e:
        print("An error occurred in backup_adapter:", str(e))

def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
    if shared.args.monkey_patch:
@@ -394,6 +427,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
task_type="CAUSAL_LM" task_type="CAUSAL_LM"
) )
    # == Backup the existing adapter ==
    if not always_override:
        backup_adapter(lora_file_path)

    try:
        logger.info("Creating LoRA model...")
        lora_model = get_peft_model(shared.model, config)