mirror of
https://github.com/tloen/alpaca-lora.git
synced 2024-10-01 01:05:56 -04:00
HF export script
This commit is contained in:
parent
8aecde83cd
commit
3b160d745b
@ -35,10 +35,12 @@ as well as some code related to prompt construction and tokenization.
|
||||
Near the top of this file is a set of hardcoded hyperparameters that you should feel free to modify.
|
||||
PRs adapting this code to multi-GPU setups and larger models are always welcome.
|
||||
|
||||
### Checkpoint export (`export_state_dict_checkpoint.py`)
|
||||
### Checkpoint export (`export_*_checkpoint.py`)
|
||||
|
||||
This file contains a script to convert the LoRA back into a standard PyTorch model checkpoint,
|
||||
which should help users who want to use the model with projects like [llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||
These files contain scripts that merge the LoRA weights back into the base model
|
||||
for export to Hugging Face format and to PyTorch `state_dicts`,
|
||||
which should help users who want to export LlamaModel-shaped weights or
|
||||
use the model with projects like [llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||
|
||||
### Dataset
|
||||
|
||||
@ -56,7 +58,6 @@ as well as [clusters of bad examples](https://atlas.nomic.ai/map/d2139cc3-bc1c-4
|
||||
- We can likely improve our model performance significantly if we combed through the data and fixed bad examples; in fact, dataset quality might be our bottleneck.
|
||||
- We're continually fixing bugs and conducting training runs, and the weights on the Hugging Face Hub are being updated accordingly. In particular, those facing issues with response lengths should make sure that they have the latest version of the weights and code.
|
||||
|
||||
|
||||
### Example outputs
|
||||
|
||||
**Instruction**: Tell me about alpacas.
|
||||
|
56
export_hf_checkpoint.py
Normal file
56
export_hf_checkpoint.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from peft import PeftModel, LoraConfig
|
||||
|
||||
import transformers
|
||||
|
||||
assert (
|
||||
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
|
||||
|
||||
base_model = LlamaForCausalLM.from_pretrained(
|
||||
"decapoda-research/llama-7b-hf",
|
||||
load_in_8bit=False,
|
||||
torch_dtype=torch.float16,
|
||||
device_map={"": "cpu"},
|
||||
)
|
||||
|
||||
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
|
||||
first_weight_old = first_weight.clone()
|
||||
|
||||
lora_model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
"tloen/alpaca-lora-7b",
|
||||
device_map={"": "cpu"},
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight
|
||||
|
||||
assert torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
# merge weights
|
||||
for layer in lora_model.base_model.model.model.layers:
|
||||
layer.self_attn.q_proj.merge_weights = True
|
||||
layer.self_attn.v_proj.merge_weights = True
|
||||
|
||||
lora_model.train(False)
|
||||
|
||||
# did we do anything?
|
||||
assert not torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
lora_model_sd = lora_model.state_dict()
|
||||
deloreanized_sd = {
|
||||
k.replace("base_model.model.model", "model"): v
|
||||
for k, v in lora_model_sd.items()
|
||||
if "lora" not in k
|
||||
}
|
||||
|
||||
LlamaForCausalLM.save_pretrained(
|
||||
base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
|
||||
)
|
Loading…
Reference in New Issue
Block a user