LoRA: Fix error "Attempting to unscale FP16 gradients" when training (#5268)

ilya sheprut 2024-01-17 23:11:49 +03:00 committed by GitHub
parent 535ea9928a
commit 4d14eb8b82


@@ -605,6 +605,11 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
                 control.should_training_stop = True
                 print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
 
+    # Fix training for mixed precision models
+    for param in shared.model.parameters():
+        if param.requires_grad:
+            param.data = param.data.float()
+
     trainer = transformers.Trainer(
         model=lora_model,
         train_dataset=train_data,
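
For context (not part of the commit itself): when the base model is loaded in fp16, the trainable LoRA parameters are also fp16, and the mixed-precision gradient scaler used by transformers.Trainer refuses to unscale fp16 gradients, raising "Attempting to unscale FP16 gradients". The snippet below is a minimal standalone sketch of the same pattern on a plain PyTorch module; the module and parameter setup are illustrative, not taken from text-generation-webui.

import torch.nn as nn

# Illustrative stand-in for a half-precision base model with LoRA-style
# trainable parameters: weights are fp16, and only some require gradients.
model = nn.Linear(16, 16).half()
model.weight.requires_grad_(True)   # "adapter" weight we want to train
model.bias.requires_grad_(False)    # frozen parameter stays fp16

# Same pattern as the commit: upcast every trainable parameter to fp32 so
# the gradient scaler never has to unscale fp16 gradients during training.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

# Trainable parameters are now torch.float32; frozen ones keep torch.float16.
print({name: p.dtype for name, p in model.named_parameters()})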