mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
fix: update train scripts and configs for other models (#1164)
* feat: falcon config * feat: mpt config * chore: gitignore * refactor: step calculation * fix: attention mask + shuffle on epoch end * fix: return tensors * fix: wait for everyone * chore: config * chore: ds config * fix: remove ccols * fix: logging and saving * chore: add einops
This commit is contained in:
parent
e8b19b8e82
commit
6c4f449b7a
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,6 @@
|
|||||||
|
*.arrow
|
||||||
|
squad_*
|
||||||
|
*sbert_embedded*
|
||||||
*.pkl
|
*.pkl
|
||||||
ckpts*
|
ckpts*
|
||||||
.deepspeed_env
|
.deepspeed_env
|
||||||
|
49
gpt4all-training/configs/deepspeed/ds_config_mpt.json
Normal file
49
gpt4all-training/configs/deepspeed/ds_config_mpt.json
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
{
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"min_loss_scale": 1,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"initial_scale_power": 32
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"gradient_clipping": 1.0,
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 1,
|
||||||
|
"offload_param": {
|
||||||
|
"device": "none"
|
||||||
|
},
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "none"
|
||||||
|
},
|
||||||
|
"allgather_partitions": true,
|
||||||
|
"allgather_bucket_size": 5e8,
|
||||||
|
"contiguous_gradients": true
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": [
|
||||||
|
0.9,
|
||||||
|
0.999
|
||||||
|
],
|
||||||
|
"eps": 1e-08
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": 0,
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
48
gpt4all-training/configs/deepspeed/ds_config_pythia.json
Normal file
48
gpt4all-training/configs/deepspeed/ds_config_pythia.json
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
{
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"min_loss_scale": 1,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"initial_scale_power": 32
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"gradient_clipping": 1.0,
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"offload_param": {
|
||||||
|
"device": "none"
|
||||||
|
},
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "none"
|
||||||
|
},
|
||||||
|
"allgather_partitions": true,
|
||||||
|
"allgather_bucket_size": 5e8,
|
||||||
|
"contiguous_gradients": true
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": [
|
||||||
|
0.9,
|
||||||
|
0.999
|
||||||
|
],
|
||||||
|
"eps": 1e-08
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": 0,
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
34
gpt4all-training/configs/train/finetune_falcon.yaml
Normal file
34
gpt4all-training/configs/train/finetune_falcon.yaml
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# model/tokenizer
|
||||||
|
model_name: "tiiuae/falcon-7b"
|
||||||
|
tokenizer_name: "tiiuae/falcon-7b"
|
||||||
|
gradient_checkpointing: true
|
||||||
|
save_name: "nomic-ai/gpt4all-falcon"
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
streaming: false
|
||||||
|
num_proc: 64
|
||||||
|
dataset_path: "nomic-ai/gpt4all-j-prompt-generations"
|
||||||
|
revision: "v1.3-groovy"
|
||||||
|
max_length: 1024
|
||||||
|
batch_size: 32
|
||||||
|
|
||||||
|
# train dynamics
|
||||||
|
lr: 2.0e-5
|
||||||
|
min_lr: 0
|
||||||
|
weight_decay: 0.0
|
||||||
|
eval_every: 500
|
||||||
|
eval_steps: 105
|
||||||
|
save_every: 1000
|
||||||
|
log_grads_every: 500
|
||||||
|
output_dir: "ckpts/falcon"
|
||||||
|
checkpoint: "/home/paperspace/gpt4all/ckpts/mpt/step_1000"
|
||||||
|
lora: false
|
||||||
|
warmup_steps: 500
|
||||||
|
num_epochs: 2
|
||||||
|
|
||||||
|
# logging
|
||||||
|
wandb: true
|
||||||
|
wandb_entity: "gpt4all"
|
||||||
|
wandb_project_name: "gpt4all"
|
||||||
|
seed: 42
|
||||||
|
|
34
gpt4all-training/configs/train/finetune_mpt.yaml
Normal file
34
gpt4all-training/configs/train/finetune_mpt.yaml
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# model/tokenizer
|
||||||
|
model_name: "mosaicml/mpt-7b"
|
||||||
|
tokenizer_name: "mosaicml/mpt-7b"
|
||||||
|
gradient_checkpointing: false
|
||||||
|
save_name: "nomic-ai/mpt-finetuned-round2"
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
streaming: false
|
||||||
|
num_proc: 64
|
||||||
|
dataset_path: "nomic-ai/gpt4all-j-prompt-generations"
|
||||||
|
revision: "v1.3-groovy"
|
||||||
|
max_length: 1024
|
||||||
|
batch_size: 8
|
||||||
|
|
||||||
|
# train dynamics
|
||||||
|
lr: 2.0e-5
|
||||||
|
min_lr: 0
|
||||||
|
weight_decay: 0.0
|
||||||
|
eval_every: 500
|
||||||
|
eval_steps: 105
|
||||||
|
save_every: 1000
|
||||||
|
log_grads_every: 500
|
||||||
|
output_dir: "ckpts/mpt"
|
||||||
|
checkpoint: null
|
||||||
|
lora: false
|
||||||
|
warmup_steps: 500
|
||||||
|
num_epochs: 2
|
||||||
|
|
||||||
|
# logging
|
||||||
|
wandb: false
|
||||||
|
wandb_entity: "gpt4all"
|
||||||
|
wandb_project_name: "gpt4all"
|
||||||
|
seed: 42
|
||||||
|
|
34
gpt4all-training/configs/train/finetune_openllama.yaml
Normal file
34
gpt4all-training/configs/train/finetune_openllama.yaml
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# model/tokenizer
|
||||||
|
model_name: "openlm-research/open_llama_7b"
|
||||||
|
tokenizer_name: "openlm-research/open_llama_7b"
|
||||||
|
gradient_checkpointing: true
|
||||||
|
save_name: "nomic-ai/gpt4all-openllama"
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
streaming: false
|
||||||
|
num_proc: 64
|
||||||
|
dataset_path: "nomic-ai/gpt4all-updated"
|
||||||
|
revision: null
|
||||||
|
max_length: 1024
|
||||||
|
batch_size: 32
|
||||||
|
|
||||||
|
# train dynamics
|
||||||
|
lr: 2.0e-5
|
||||||
|
min_lr: 0
|
||||||
|
weight_decay: 0.0
|
||||||
|
eval_every: 500
|
||||||
|
log_every: 10
|
||||||
|
save_every: 1000
|
||||||
|
log_grads_every: 500
|
||||||
|
output_dir: "ckpts/falcon"
|
||||||
|
checkpoint: null
|
||||||
|
lora: false
|
||||||
|
warmup_steps: 500
|
||||||
|
num_epochs: 3
|
||||||
|
|
||||||
|
# logging
|
||||||
|
wandb: true
|
||||||
|
wandb_entity: "gpt4all"
|
||||||
|
wandb_project_name: "gpt4all"
|
||||||
|
seed: 42
|
||||||
|
|
@ -12,7 +12,7 @@ def tokenize_inputs(config, tokenizer, examples):
|
|||||||
|
|
||||||
# hacky backward compatible
|
# hacky backward compatible
|
||||||
different_eos = tokenizer.eos_token != "</s>"
|
different_eos = tokenizer.eos_token != "</s>"
|
||||||
out = {"labels": [], "input_ids": []}
|
out = {"labels": [], "input_ids": [], "attention_mask": []}
|
||||||
for prompt, response in zip(examples["prompt"], examples["response"]):
|
for prompt, response in zip(examples["prompt"], examples["response"]):
|
||||||
if different_eos:
|
if different_eos:
|
||||||
if response.count("</s> \n") > 0:
|
if response.count("</s> \n") > 0:
|
||||||
@ -49,9 +49,10 @@ def tokenize_inputs(config, tokenizer, examples):
|
|||||||
print(response)
|
print(response)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
|
padded = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length, return_tensors="pt")
|
||||||
out["labels"].append(labels)
|
out["labels"].append(labels)
|
||||||
out["input_ids"].append(input_tokens)
|
out["input_ids"].append(padded["input_ids"])
|
||||||
|
out["attention_mask"].append(padded["attention_mask"])
|
||||||
|
|
||||||
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ def load_data(config, tokenizer):
|
|||||||
dataset = load_dataset("json", data_files=files, split="train")
|
dataset = load_dataset("json", data_files=files, split="train")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
dataset = load_dataset(dataset_path, split="train")
|
dataset = load_dataset(dataset_path, split="train", revision=config["revision"] if "revision" in config else None)
|
||||||
|
|
||||||
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
||||||
|
|
||||||
@ -83,19 +84,23 @@ def load_data(config, tokenizer):
|
|||||||
else:
|
else:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
|
cols_to_keep = ["input_ids", "labels", "attention_mask"]
|
||||||
# tokenize inputs and return labels and attention mask
|
# tokenize inputs and return labels and attention mask
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||||
batched=True,
|
batched=True,
|
||||||
remove_columns=["source", "prompt"],
|
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
remove_cols = [col for col in train_dataset.column_names if col not in cols_to_keep]
|
||||||
|
train_dataset = train_dataset.remove_columns(remove_cols)
|
||||||
|
|
||||||
val_dataset = val_dataset.map(
|
val_dataset = val_dataset.map(
|
||||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||||
batched=True,
|
batched=True,
|
||||||
remove_columns=["source", "prompt"],
|
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
remove_cols = [col for col in val_dataset.column_names if col not in cols_to_keep]
|
||||||
|
val_dataset = val_dataset.remove_columns(remove_cols)
|
||||||
|
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
val_dataset = val_dataset.with_format("torch")
|
val_dataset = val_dataset.with_format("torch")
|
||||||
@ -106,12 +111,14 @@ def load_data(config, tokenizer):
|
|||||||
train_dataset,
|
train_dataset,
|
||||||
collate_fn=DefaultDataCollator(),
|
collate_fn=DefaultDataCollator(),
|
||||||
batch_size=config["batch_size"],
|
batch_size=config["batch_size"],
|
||||||
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
val_dataloader = DataLoader(
|
val_dataloader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
collate_fn=DefaultDataCollator(),
|
collate_fn=DefaultDataCollator(),
|
||||||
batch_size=config["batch_size"],
|
batch_size=config["batch_size"],
|
||||||
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_dataloader, val_dataloader
|
return train_dataloader, val_dataloader
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
accelerate
|
accelerate
|
||||||
datasets
|
datasets
|
||||||
|
einops
|
||||||
torchmetrics
|
torchmetrics
|
||||||
evaluate
|
evaluate
|
||||||
transformers>=4.28.0
|
transformers>=4.28.0
|
||||||
wandb
|
wandb
|
||||||
pip
|
|
||||||
peft
|
peft
|
||||||
nodelist-inflator
|
nodelist-inflator
|
||||||
deepspeed
|
deepspeed
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
|
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
@ -42,7 +42,7 @@ def train(accelerator, config):
|
|||||||
accelerator.print(config)
|
accelerator.print(config)
|
||||||
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
|
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'], use_fast=False)
|
||||||
# if no pad token, set it to eos
|
# if no pad token, set it to eos
|
||||||
if tokenizer.pad_token is None:
|
if tokenizer.pad_token is None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@ -53,6 +53,7 @@ def train(accelerator, config):
|
|||||||
|
|
||||||
|
|
||||||
checkpoint = config["gradient_checkpointing"]
|
checkpoint = config["gradient_checkpointing"]
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
||||||
use_cache=False if checkpoint else True,
|
use_cache=False if checkpoint else True,
|
||||||
trust_remote_code=True)
|
trust_remote_code=True)
|
||||||
@ -86,7 +87,7 @@ def train(accelerator, config):
|
|||||||
# decay to min_lr instead of 0
|
# decay to min_lr instead of 0
|
||||||
lr_ratio = config["min_lr"] / config["lr"]
|
lr_ratio = config["min_lr"] / config["lr"]
|
||||||
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
|
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
|
||||||
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
|
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * (config["num_epochs"])
|
||||||
# instead of decaying to zero, decay to ratio of min_lr / lr
|
# instead of decaying to zero, decay to ratio of min_lr / lr
|
||||||
total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
|
total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
|
||||||
accelerator.print(f"Total training steps: {total_num_steps}")
|
accelerator.print(f"Total training steps: {total_num_steps}")
|
||||||
@ -104,7 +105,7 @@ def train(accelerator, config):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
scheduler = DummyScheduler(
|
scheduler = DummyScheduler(
|
||||||
optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
|
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
|
||||||
)
|
)
|
||||||
|
|
||||||
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
|
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
|
||||||
@ -117,26 +118,34 @@ def train(accelerator, config):
|
|||||||
if config["checkpoint"]:
|
if config["checkpoint"]:
|
||||||
accelerator.load_state(config["checkpoint"])
|
accelerator.load_state(config["checkpoint"])
|
||||||
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
|
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
|
||||||
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
|
path = os.path.basename(config["checkpoint"])
|
||||||
training_difference = os.path.splitext(path)[0]
|
training_difference = os.path.splitext(path)[0]
|
||||||
resume_step = int(training_difference.replace("step_", ""))
|
resume_step = int(training_difference.replace("step_", ""))
|
||||||
accelerator.skip_first_batches(train_dataloader, resume_step)
|
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||||
accelerator.print(f"Resuming from step {resume_step}")
|
accelerator.print(f"Resuming from step {resume_step}")
|
||||||
|
else:
|
||||||
|
resume_step = 0
|
||||||
|
|
||||||
|
|
||||||
# log gradients
|
# log gradients
|
||||||
if accelerator.is_main_process and config["wandb"]:
|
if accelerator.is_main_process and config["wandb"]:
|
||||||
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
|
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
|
||||||
|
|
||||||
for epoch in range(config["num_epochs"]):
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
for epoch in range(0, config["num_epochs"]):
|
||||||
train_loss = MeanMetric(nan_strategy="error").to(model.device)
|
train_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||||
for step, batch in enumerate(tqdm(train_dataloader)):
|
for step, batch in enumerate(tqdm(train_dataloader)):
|
||||||
|
curr_step = epoch * len(train_dataloader) + step
|
||||||
model.train()
|
model.train()
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
|
|
||||||
# gather loss before backprop in case of gradient accumulation
|
# gather loss before backprop in case of gradient accumulation
|
||||||
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
|
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
|
||||||
|
if config["wandb"]:
|
||||||
|
accelerator.log({"loss": torch.mean(loss_values["loss"]).item()}, step=curr_step)
|
||||||
train_loss.update(loss_values["loss"])
|
train_loss.update(loss_values["loss"])
|
||||||
|
|
||||||
loss = loss / gradient_accumulation_steps
|
loss = loss / gradient_accumulation_steps
|
||||||
@ -144,9 +153,8 @@ def train(accelerator, config):
|
|||||||
# get gradient norm of all params
|
# get gradient norm of all params
|
||||||
|
|
||||||
# log LR in case something weird happens
|
# log LR in case something weird happens
|
||||||
if step > 0 and step % (config["eval_every"] // 10) == 0:
|
if step > 0 and step % (config["log_lr_every"]) == 0:
|
||||||
if config["wandb"]:
|
if config["wandb"]:
|
||||||
curr_step = step + epoch * len(train_dataloader)
|
|
||||||
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
|
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
|
||||||
|
|
||||||
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||||
@ -156,7 +164,6 @@ def train(accelerator, config):
|
|||||||
|
|
||||||
|
|
||||||
if step > 0 and step % config["save_every"] == 0:
|
if step > 0 and step % config["save_every"] == 0:
|
||||||
curr_step = step + epoch * len(train_dataloader)
|
|
||||||
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
|
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
|
||||||
|
|
||||||
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
||||||
@ -170,7 +177,6 @@ def train(accelerator, config):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config["wandb"]:
|
if config["wandb"]:
|
||||||
curr_step = step + epoch * len(train_dataloader)
|
|
||||||
accelerator.log({**log_train, **log_val}, step=curr_step)
|
accelerator.log({**log_train, **log_val}, step=curr_step)
|
||||||
|
|
||||||
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
|
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
|
||||||
@ -181,8 +187,14 @@ def train(accelerator, config):
|
|||||||
|
|
||||||
accelerator.print(f"Epoch {epoch} finished")
|
accelerator.print(f"Epoch {epoch} finished")
|
||||||
accelerator.print(f"Pushing to HF hub")
|
accelerator.print(f"Pushing to HF hub")
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
f"{config['output_dir']}/epoch_{epoch}",
|
||||||
|
is_main_process=accelerator.is_main_process,
|
||||||
|
save_function=accelerator.save,
|
||||||
|
state_dict=accelerator.get_state_dict(model),
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
||||||
@ -191,21 +203,16 @@ def train(accelerator, config):
|
|||||||
accelerator.print(e)
|
accelerator.print(e)
|
||||||
accelerator.print(f"Failed to push to hub")
|
accelerator.print(f"Failed to push to hub")
|
||||||
|
|
||||||
|
|
||||||
|
if config["num_epochs"] > 1:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
unwrapped_model.save_pretrained(
|
unwrapped_model.save_pretrained(
|
||||||
f"{config['output_dir']}/epoch_{epoch}",
|
f"{config['output_dir']}/final",
|
||||||
is_main_process=accelerator.is_main_process,
|
is_main_process=accelerator.is_main_process,
|
||||||
save_function=accelerator.save,
|
save_function=accelerator.save,
|
||||||
state_dict=accelerator.get_state_dict(model),
|
state_dict=accelerator.get_state_dict(model),
|
||||||
)
|
)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
|
||||||
unwrapped_model.save_pretrained(
|
|
||||||
f"{config['output_dir']}/final",
|
|
||||||
is_main_process=accelerator.is_main_process,
|
|
||||||
save_function=accelerator.save,
|
|
||||||
state_dict=accelerator.get_state_dict(model),
|
|
||||||
)
|
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user