diff --git a/modules/training.py b/modules/training.py index 3a9b4146..7be0d24f 100644 --- a/modules/training.py +++ b/modules/training.py @@ -305,15 +305,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch time.sleep(5) - if shared.args.wbits > 0 and not shared.args.monkey_patch: - yield "LoRA training with GPTQ models requires loading with `--monkey-patch`" + if shared.args.loader == 'GPTQ-for-LLaMa' and not shared.args.monkey_patch: + yield "LoRA training with GPTQ-for-LLaMa requires loading with `--monkey-patch`" return - elif not (shared.args.load_in_8bit or shared.args.load_in_4bit) and shared.args.wbits <= 0: - yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*" - logger.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.") - time.sleep(2) # Give it a moment for the message to show in UI before continuing - if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: yield "Cannot input zeroes." return