Merge pull request #3 from nomic-ai/train

Log wandb metrics with a global step that keeps increasing across epochs
This commit is contained in:
Andriy Mulyar 2023-03-29 13:50:26 -04:00 committed by GitHub
commit 252676ff05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -127,7 +127,8 @@ def train(accelerator, config):
# log LR in case something weird happens # log LR in case something weird happens
if step > 0 and step % (config["eval_every"] // 10) == 0: if step > 0 and step % (config["eval_every"] // 10) == 0:
if config["wandb"]: if config["wandb"]:
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step) curr_step = step + epoch * len(train_dataloader)
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step() optimizer.step()
@@ -151,7 +152,8 @@ def train(accelerator, config):
} }
if config["wandb"]: if config["wandb"]:
accelerator.log({**log_train, **log_val}, step=step) curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}") accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
accelerator.print(format_metrics(log_train, "train", f" step {step} ")) accelerator.print(format_metrics(log_train, "train", f" step {step} "))