diff --git a/modules/training.py b/modules/training.py
index 2830ba07..0410ddd1 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -605,6 +605,11 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
                     control.should_training_stop = True
                     print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
 
+    # Fix training for mixed precision models
+    for param in shared.model.parameters():
+        if param.requires_grad:
+            param.data = param.data.float()
+
     trainer = transformers.Trainer(
         model=lora_model,
         train_dataset=train_data,
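For context on what the added loop does: when the base model is loaded in half precision, keeping the trainable parameters in float16 makes their optimizer updates prone to underflow, so the patch promotes just those parameters to float32 while leaving frozen half-precision weights untouched. Below is a minimal standalone sketch of the same technique; the toy `model` is illustrative, not the repository's `shared.model`.

```python
import torch
import torch.nn as nn

# Toy stand-in: a "base" layer frozen in float16 plus a trainable head,
# loosely analogous to a half-precision base model with LoRA adapters.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).half()
for p in model[0].parameters():
    p.requires_grad = False  # frozen base weights stay in float16

# The cast from the patch: promote only the trainable parameters to float32.
# Small gradient updates applied to float16 weights can underflow to zero,
# stalling training; full-precision trainable weights avoid that.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

print({name: p.dtype for name, p in model.named_parameters()})
# frozen parameters remain torch.float16; trainable ones are now torch.float32
```

In the patch this loop runs right before `transformers.Trainer` is constructed, so the optimizer the trainer builds sees float32 trainable weights from the start.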