diff --git a/export_hf_checkpoint.py b/export_hf_checkpoint.py
index 0785426..c8b2137 100644
--- a/export_hf_checkpoint.py
+++ b/export_hf_checkpoint.py
@@ -8,7 +8,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
 BASE_MODEL = os.environ.get("BASE_MODEL", None)
 assert (
     BASE_MODEL
-), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
+), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`"  # noqa: E501
 
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
@@ -35,10 +35,8 @@ lora_weight = lora_model.base_model.model.model.layers[
 
 assert torch.allclose(first_weight_old, first_weight)
 
-# merge weights
-for layer in lora_model.base_model.model.model.layers:
-    layer.self_attn.q_proj.merge_weights = True
-    layer.self_attn.v_proj.merge_weights = True
+# merge weights - new merging method from peft
+lora_model = lora_model.merge_and_unload()
 
 lora_model.train(False)
 
diff --git a/export_state_dict_checkpoint.py b/export_state_dict_checkpoint.py
index 0bb0d81..23dd9b7 100644
--- a/export_state_dict_checkpoint.py
+++ b/export_state_dict_checkpoint.py
@@ -9,7 +9,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 BASE_MODEL = os.environ.get("BASE_MODEL", None)
 assert (
     BASE_MODEL
-), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
+), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`"  # noqa: E501
 
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
diff --git a/finetune.py b/finetune.py
index 6760fa9..a6f6028 100644
--- a/finetune.py
+++ b/finetune.py
@@ -83,7 +83,7 @@ def train(
     )
     assert (
         base_model
-    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
     gradient_accumulation_steps = batch_size // micro_batch_size
 
     prompter = Prompter(prompt_template_name)
diff --git a/generate.py b/generate.py
index 3335932..4e1a9d7 100644
--- a/generate.py
+++ b/generate.py
@@ -34,7 +34,7 @@ def main(
     base_model = base_model or os.environ.get("BASE_MODEL", "")
     assert (
         base_model
-    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
 
     prompter = Prompter(prompt_template)
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
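
Note on the export_hf_checkpoint.py hunk: older peft releases exposed a per-layer `merge_weights` flag on the LoRA linear layers, which newer releases removed; `merge_and_unload()` is the supported replacement and returns the unwrapped base model with the adapter deltas folded into its weights. A minimal sketch of the new flow under current peft (the adapter id and output path below are illustrative placeholders, not taken from this diff):

```python
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM

# Load the base model, then wrap it with the LoRA adapter.
base_model = LlamaForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    torch_dtype=torch.float16,
)
lora_model = PeftModel.from_pretrained(base_model, "tloen/alpaca-lora-7b")

# merge_and_unload() folds each LoRA delta (scaling * B @ A) into the wrapped
# base weights and returns the plain LlamaForCausalLM, so the old per-layer
# `merge_weights = True` loop is no longer needed.
merged_model = lora_model.merge_and_unload()
merged_model.save_pretrained("./hf_ckpt")
```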