fix: saving name

This commit is contained in:
Zach Nussbaum 2023-04-08 20:56:13 +00:00
parent 633df8edb4
commit 9efdf56e38

View File

@ -192,7 +192,7 @@ def train(accelerator, config):
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/-epoch_{epoch}",
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),