fix: num training steps for lr decay

This commit is contained in:
Zach Nussbaum 2023-04-10 02:15:31 +00:00
parent 195f8a7d4e
commit 9dfd8e1a7c

View File

@ -100,7 +100,7 @@ def train(accelerator, config):
     name="cosine",
     optimizer=optimizer,
     num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
-    num_training_steps=total_num_steps * accelerator.num_processes,
+    num_training_steps=total_num_steps,
 )
 else:
     scheduler = DummyScheduler(