Fix PEFT LoRA unloading

oobabooga 2023-11-19 07:55:25 -08:00
parent a290d17386
commit a6f1e1bcc5


@@ -149,10 +149,7 @@ def add_lora_transformers(lora_names):
 
     # If any LoRA needs to be removed, start over
     if len(removed_set) > 0:
-        # shared.model may no longer be PeftModel
-        if hasattr(shared.model, 'disable_adapter'):
-            shared.model.disable_adapter()
-            shared.model = shared.model.base_model.model
+        shared.model = shared.model.unload()
 
     if len(lora_names) > 0:
         params = {}
@@ -172,8 +169,6 @@ def add_lora_transformers(lora_names):
         if len(lora_names) > 1:
             merge_loras()
 
-        shared.lora_names = lora_names
-
         if not shared.args.load_in_8bit and not shared.args.cpu:
            shared.model.half()
            if not hasattr(shared.model, "hf_device_map"):
@@ -186,6 +181,8 @@ def add_lora_transformers(lora_names):
             else:
                 shared.model = shared.model.cuda()
 
+    shared.lora_names = lora_names
+
 
 def merge_loras():
     if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
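
For context on the first hunk: PEFT's unload() returns the underlying base model with the injected LoRA layers stripped out, which is what lets the old disable_adapter() / base_model.model unwrapping collapse into a single call. A minimal sketch of that behavior, not taken from this commit; the base model name and adapter path below are placeholders:

    # Sketch of PeftModel.unload() replacing manual adapter removal.
    # "facebook/opt-125m" and "path/to/lora-adapter" are placeholder names.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")   # placeholder base model
    model = PeftModel.from_pretrained(base, "path/to/lora-adapter")    # placeholder adapter path

    # Old approach: disable the adapter, then manually unwrap the base model:
    #   model.disable_adapter()
    #   model = model.base_model.model
    # Approach used by this commit: one call removes the LoRA layers and
    # hands back the plain transformers model.
    model = model.unload()

The remaining hunks move the shared.lora_names = lora_names assignment out of the "if len(lora_names) > 0" block to the end of the function, so the tracked list is also updated when every LoRA has been unloaded.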