From eb30f4441f855342c2dddabdc3e152e6ee919219 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 19 Jun 2023 12:31:24 -0300 Subject: [PATCH] Add ExLlama+LoRA support (#2756) --- modules/LoRA.py | 170 +++++++++++++++++++++--------------- modules/exllama.py | 9 +- modules/relative_imports.py | 13 +++ 3 files changed, 119 insertions(+), 73 deletions(-) create mode 100644 modules/relative_imports.py diff --git a/modules/LoRA.py b/modules/LoRA.py index 0803f928..fb49ae6f 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -7,85 +7,117 @@ import modules.shared as shared from modules.logging_colors import logger from modules.models import reload_model -try: - from auto_gptq import get_gptq_peft_model - from auto_gptq.utils.peft_utils import GPTQLoraConfig - has_auto_gptq_peft = True -except: - has_auto_gptq_peft = False - def add_lora_to_model(lora_names): + if 'GPTQForCausalLM' in shared.model.__class__.__name__: + add_lora_autogptq(lora_names) + elif shared.model.__class__.__name__ == 'ExllamaModel': + add_lora_exllama(lora_names) + else: + add_lora_transformers(lora_names) + + +def add_lora_exllama(lora_names): + + try: + from repositories.exllama.lora import ExLlamaLora + except: + logger.error("Could not find the file repositories/exllama/lora.py. Make sure that exllama is cloned inside repositories/ and is up to date.") + return + + if len(lora_names) == 0: + shared.model.generator.lora = None + shared.lora_names = [] + return + else: + if len(lora_names) > 1: + logger.warning('ExLlama can only work with 1 LoRA at the moment. Only the first one in the list will be loaded.') + + lora_path = Path(f"{shared.args.lora_dir}/{lora_names[0]}") + lora_config_path = lora_path / "adapter_config.json" + lora_adapter_path = lora_path / "adapter_model.bin" + + logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]]))) + lora = ExLlamaLora(shared.model.model, str(lora_config_path), str(lora_adapter_path)) + shared.model.generator.lora = lora + shared.lora_names = [lora_names[0]] + return + + +# Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing +def add_lora_autogptq(lora_names): + + try: + from auto_gptq import get_gptq_peft_model + from auto_gptq.utils.peft_utils import GPTQLoraConfig + except: + logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.") + return + + if len(lora_names) == 0: + if len(shared.lora_names) > 0: + reload_model() + + shared.lora_names = [] + return + else: + if len(lora_names) > 1: + logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded.') + + peft_config = GPTQLoraConfig( + inference_mode=True, + ) + + lora_path = Path(f"{shared.args.lora_dir}/{lora_names[0]}") + logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]]))) + shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path) + shared.lora_names = [lora_names[0]] + return + + +def add_lora_transformers(lora_names): prior_set = set(shared.lora_names) added_set = set(lora_names) - prior_set removed_set = prior_set - set(lora_names) - shared.lora_names = list(lora_names) - is_autogptq = 'GPTQForCausalLM' in shared.model.__class__.__name__ + # If no LoRA needs to be added or removed, exit + if len(added_set) == 0 and len(removed_set) == 0: + return - # AutoGPTQ case. It doesn't use the peft functions. - # Copied from https://github.com/Ph0rk0z/text-generation-webui-testing - if is_autogptq: - if not has_auto_gptq_peft: - logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.") - return + # Add a LoRA when another LoRA is already present + if len(removed_set) == 0 and len(prior_set) > 0: + logger.info(f"Adding the LoRA(s) named {added_set} to the model...") + for lora in added_set: + shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora) - if len(prior_set) > 0: - reload_model() + return - if len(shared.lora_names) == 0: - return - else: - if len(shared.lora_names) > 1: - logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded') + # If any LoRA needs to be removed, start over + if len(removed_set) > 0: + shared.model.disable_adapter() + shared.model = shared.model.base_model.model - peft_config = GPTQLoraConfig( - inference_mode=True, - ) + if len(lora_names) > 0: + params = {} + if not shared.args.cpu: + params['dtype'] = shared.model.dtype + if hasattr(shared.model, "hf_device_map"): + params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()} + elif shared.args.load_in_8bit: + params['device_map'] = {'': 0} - lora_path = Path(f"{shared.args.lora_dir}/{shared.lora_names[0]}") - logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]]))) - shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path) - return + logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names))) + shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params) + for lora in lora_names[1:]: + shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora) - # Transformers case - else: - # If no LoRA needs to be added or removed, exit - if len(added_set) == 0 and len(removed_set) == 0: - return + shared.lora_names = lora_names - # Add a LoRA when another LoRA is already present - if len(removed_set) == 0 and len(prior_set) > 0: - logger.info(f"Adding the LoRA(s) named {added_set} to the model...") - for lora in added_set: - shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora) - - return - - # If any LoRA needs to be removed, start over - if len(removed_set) > 0: - shared.model.disable_adapter() - shared.model = shared.model.base_model.model - - if len(lora_names) > 0: - logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names))) - params = {} - if not shared.args.cpu: - params['dtype'] = shared.model.dtype - if hasattr(shared.model, "hf_device_map"): - params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()} - elif shared.args.load_in_8bit: - params['device_map'] = {'': 0} - - shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params) - for lora in lora_names[1:]: - shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora) - - if not shared.args.load_in_8bit and not shared.args.cpu: - shared.model.half() - if not hasattr(shared.model, "hf_device_map"): - if torch.has_mps: - device = torch.device('mps') - shared.model = shared.model.to(device) - else: - shared.model = shared.model.cuda() + if not shared.args.load_in_8bit and not shared.args.cpu: + shared.model.half() + if not hasattr(shared.model, "hf_device_map"): + if torch.has_mps: + device = torch.device('mps') + shared.model = shared.model.to(device) + else: + shared.model = shared.model.cuda() diff --git a/modules/exllama.py b/modules/exllama.py index 69eccbf2..f8c04971 100644 --- a/modules/exllama.py +++ b/modules/exllama.py @@ -3,11 +3,12 @@ from pathlib import Path from modules import shared from modules.logging_colors import logger +from modules.relative_imports import RelativeImport -sys.path.insert(0, str(Path("repositories/exllama"))) -from repositories.exllama.generator import ExLlamaGenerator -from repositories.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig -from repositories.exllama.tokenizer import ExLlamaTokenizer +with RelativeImport("repositories/exllama"): + from generator import ExLlamaGenerator + from model import ExLlama, ExLlamaCache, ExLlamaConfig + from tokenizer import ExLlamaTokenizer class ExllamaModel: diff --git a/modules/relative_imports.py b/modules/relative_imports.py new file mode 100644 index 00000000..3c0eb56b --- /dev/null +++ b/modules/relative_imports.py @@ -0,0 +1,13 @@ +import sys +from pathlib import Path + + +class RelativeImport: + def __init__(self, path): + self.import_path = Path(path) + + def __enter__(self): + sys.path.insert(0, str(self.import_path)) + + def __exit__(self, exc_type, exc_value, traceback): + sys.path.remove(str(self.import_path))