transformers loader: multi-LoRAs support (#3120)

Googulator 2023-10-22 21:06:22 +02:00 committed by GitHub
parent 4405513ca5
commit d0c3b407b3
2 changed files with 17 additions and 3 deletions


@@ -2,13 +2,13 @@
| Loader | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation |
|----------------|----------------|-------------------------|----------------|----------------------|-----------------------|
| Transformers | ✅ | ❌ | ✅* | ✅ | ✅ |
| Transformers | ✅ | ✅ | ✅* | ✅ | ✅ |
| ExLlama_HF | ✅ | ❌ | ❌ | ❌ | ✅ |
| ExLlamav2_HF | ✅ | ✅ | ❌ | ❌ | ✅ |
| ExLlama | ✅ | ❌ | ❌ | ❌ | use ExLlama_HF |
| ExLlamav2 | ✅ | ✅ | ❌ | ❌ | use ExLlamav2_HF |
| AutoGPTQ | ✅ | ❌ | ❌ | ✅ | ✅ |
| GPTQ-for-LLaMa | ✅** | ❌ | ✅ | ✅ | ✅ |
| GPTQ-for-LLaMa | ✅** | ✅ | ✅ | ✅ | ✅ |
| llama.cpp | ❌ | ❌ | ❌ | ❌ | use llamacpp_HF |
| llamacpp_HF | ❌ | ❌ | ❌ | ❌ | ✅ |
| ctransformers | ❌ | ❌ | ❌ | ❌ | ❌ |


@@ -8,6 +8,14 @@ from modules.logging_colors import logger
from modules.models import reload_model


def merge_loras():
    if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
        logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
        return

    shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged")
    shared.model.set_adapter("__merged")


def add_lora_to_model(lora_names):
    if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
        add_lora_autogptq(lora_names)
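
The new `merge_loras()` helper builds on PEFT's multi-adapter API. Below is a minimal, self-contained sketch of the same calls outside the webui; the base model name, adapter paths, and adapter names are placeholders, not part of this repository:

```python
# Minimal sketch of the PEFT calls that merge_loras() relies on.
# "some-base-model", "loras/lora-a", "loras/lora-b" are placeholder names/paths.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("some-base-model")
model = PeftModel.from_pretrained(base, "loras/lora-a", adapter_name="lora-a")
model.load_adapter("loras/lora-b", "lora-b")

# merge_loras() first checks that every loaded adapter has the same rank
# (model.peft_config[name].r); only then does it combine them with equal
# weights into a single "__merged" adapter and activate it.
if len({cfg.r for cfg in model.peft_config.values()}) == 1:
    model.add_weighted_adapter(["lora-a", "lora-b"], [1, 1], "__merged")
    model.set_adapter("__merged")
```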
@@ -136,11 +144,14 @@ def add_lora_transformers(lora_names):
        return

    # Add a LoRA when another LoRA is already present
    if len(removed_set) == 0 and len(prior_set) > 0:
    if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
        logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
        for lora in added_set:
            shared.model.load_adapter(get_lora_path(lora), lora)

        if len(lora_names) > 1:
            merge_loras()

        return

    # If any LoRA needs to be removed, start over
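
This hunk adjusts the fast path that adds a LoRA while others are already loaded: it is skipped once a `__merged` adapter exists in `peft_config` (a merged set cannot be extended incrementally), and after loading the requested adapters it merges them whenever more than one LoRA is selected. A hypothetical call sequence, assuming a model already loaded through the Transformers loader, placeholder adapter names, and that these functions live in `modules/LoRA.py` (the file name is not shown in this view):

```python
# Hypothetical call sequence; adapter names are placeholders and the module
# path is assumed, since the diff view does not show the file name.
from modules.LoRA import add_lora_to_model

add_lora_to_model(["lora-a"])            # loads a single adapter, nothing to merge
add_lora_to_model(["lora-a", "lora-b"])  # fast path: load "lora-b", then merge_loras()

# peft_config now contains "__merged", so asking for a third adapter skips the
# fast path above and is handled by the remaining logic further down in
# add_lora_transformers (not shown in this hunk).
add_lora_to_model(["lora-a", "lora-b", "lora-c"])
```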
@@ -165,6 +176,9 @@ def add_lora_transformers(lora_names):
        for lora in lora_names[1:]:
            shared.model.load_adapter(get_lora_path(lora), lora)

        if len(lora_names) > 1:
            merge_loras()

        shared.lora_names = lora_names

        if not shared.args.load_in_8bit and not shared.args.cpu:
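
The last hunk covers the initial load path: after the first LoRA is attached earlier in the function (not shown here), any remaining ones are added with `load_adapter`, and the set is merged when more than one was requested. A quick, hypothetical sanity check after such a load, with placeholder adapter names; `peft_config` and `active_adapter` are standard `PeftModel` attributes:

```python
# Hypothetical check after loading two LoRAs through this path
# (adapter names are placeholders).
from modules import shared

print(list(shared.model.peft_config.keys()))  # e.g. ['lora-a', 'lora-b', '__merged']
print(shared.model.active_adapter)            # '__merged' when the ranks matched;
                                              # otherwise the first LoRA stays active
```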