Fixed the param name when loading a LoRA using a model loaded in 4 or 8 bits (#3036)

This commit is contained in:
Fernando Tarin Morales 2023-07-07 14:24:07 +09:00 committed by GitHub
parent 1f540fa4f8
commit d7e14e1f78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -114,11 +114,12 @@ def add_lora_transformers(lora_names):
if len(lora_names) > 0:
params = {}
if not shared.args.cpu:
params['dtype'] = shared.model.dtype
if hasattr(shared.model, "hf_device_map"):
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
elif shared.args.load_in_8bit:
params['device_map'] = {'': 0}
if shared.args.load_in_4bit or shared.args.load_in_8bit:
params['peft_type'] = shared.model.dtype
else:
params['dtype'] = shared.model.dtype
if hasattr(shared.model, "hf_device_map"):
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params)