fix: add epoch train

This commit is contained in:
Zach Nussbaum 2023-03-27 16:32:35 +00:00
parent bb28929305
commit 24765a1965

View File

@ -115,48 +115,56 @@ def train(accelerator, config):
"gradient_accumulation_steps"
]
for step, batch in enumerate(tqdm(train_dataloader)):
model.train()
outputs = model(**batch)
loss = outputs.loss
loss = loss / gradient_accumulation_steps
for epoch in range(config["num_epochs"]):
for step, batch in enumerate(tqdm(train_dataloader)):
model.train()
outputs = model(**batch)
loss = outputs.loss
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
accelerator.backward(loss)
# log LR in case something weird happens
if step > 0 and step % (config["eval_every"] // 10) == 0:
if config["wandb"]:
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step)
# log LR in case something weird happens
if step > 0 and step % (config["eval_every"] // 10) == 0:
if config["wandb"]:
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step)
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
train_loss.update(loss_values["loss"])
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
train_loss.update(loss_values["loss"])
if step > 0 and step % config["save_every"] == 0:
accelerator.save_state(f"{config['output_dir']}/step_{step}")
if step > 0 and step % config["save_every"] == 0:
accelerator.save_state(f"{config['output_dir']}/step_{step}")
if step > 0 and step % config["eval_every"] == 0:
val_loss = evaluate(config, model, val_dataloader)
if step > 0 and step % config["eval_every"] == 0:
val_loss = evaluate(config, model, val_dataloader)
log_train = {
"train_loss": train_loss.compute()
log_train = {
"train_loss": train_loss.compute()
}
log_val = {
"val_loss": val_loss.compute()
}
log_val = {
"val_loss": val_loss.compute()
}
if config["wandb"]:
accelerator.log({**log_train, **log_val}, step=step)
if config["wandb"]:
accelerator.log({**log_train, **log_val}, step=step)
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
train_loss.reset()
train_loss.reset()
accelerator.print(f"Epoch {epoch} finished")
accelerator.print(f"Pushing to HF hub")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"], private=True)
accelerator.wait_for_everyone()
@ -168,7 +176,8 @@ def train(accelerator, config):
state_dict=accelerator.get_state_dict(model),
)
unwrapped_model.push_to_hub(config["save_name"], private=True)
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"], private=True)
accelerator.end_training()