fix: only on first process, not once on every node

This commit is contained in:
Zach Nussbaum 2023-04-05 02:36:22 +00:00
parent d0402288bd
commit 97d4499d79

View File

@ -137,7 +137,7 @@ def train(accelerator, config):
# log gradients
if accelerator.is_local_main_process and config["wandb"]:
if accelerator.is_main_process and config["wandb"]:
wandb.watch(model, log_freq=config["log_grads_every"])
for epoch in range(config["num_epochs"]):