diff --git a/modules/training.py b/modules/training.py
index e2be18e8..0e210c52 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -90,7 +90,8 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
     evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
     # Start prepping the model itself
-    model = prepare_model_for_int8_training(model)
+    if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
+        model = prepare_model_for_int8_training(model)
     config = LoraConfig(
         r=loraRank,
         lora_alpha=loraAlpha,
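
For context, here is a minimal standalone sketch of how the guarded int8 preparation feeds into the LoRA setup. It assumes a model loaded in 8-bit via transformers and the peft library; the helper name attach_lora, the hyperparameter values, and the target_modules list are illustrative and not part of this diff.

    # Sketch: guarded int8 preparation before attaching a LoRA adapter.
    # Assumes `model` was loaded with load_in_8bit=True; values below are
    # illustrative, not taken from the patched do_train().
    from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

    def attach_lora(model, loraRank=8, loraAlpha=16):
        # Run int8 prep unless the model has an lm_head without a .weight
        # attribute (e.g. a replaced/quantized head), mirroring the condition
        # added in the hunk above.
        if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
            model = prepare_model_for_int8_training(model)

        config = LoraConfig(
            r=loraRank,
            lora_alpha=loraAlpha,
            target_modules=["q_proj", "v_proj"],  # illustrative choice
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        return get_peft_model(model, config)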