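# Merge the tloen/alpaca-lora-7b LoRA adapter into the base
# decapoda-research/llama-7b-hf weights and save the result as a
# standalone HuggingFace checkpoint under ./hf_ckpt.
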
import torch
from peft import PeftModel

import transformers

# Fail fast if this transformers install predates native LLaMA support.
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaTokenizer, LlamaForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
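
# Load the base model on CPU in fp16; 8-bit loading stays off so the merged
# weights can be exported at fp16 precision.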
base_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)
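
# Snapshot one attention weight so we can check later that merging actually
# modified the model in place.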
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()
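
# Wrap the base model with the Alpaca LoRA adapter weights.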
lora_model = PeftModel.from_pretrained(
    base_model,
    "tloen/alpaca-lora-7b",
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)
lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight

# Loading the adapter alone must not change the base weights yet.
assert torch.allclose(first_weight_old, first_weight)
# Merge weights: flag the LoRA-adapted projections for merging, then switch to
# eval mode, which makes peft fold the LoRA deltas into the base q_proj/v_proj
# matrices.
for layer in lora_model.base_model.model.model.layers:
    layer.self_attn.q_proj.merge_weights = True
    layer.self_attn.v_proj.merge_weights = True

lora_model.train(False)
# Did we do anything? The merge must have changed the weights.
assert not torch.allclose(first_weight_old, first_weight)
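
# Build a plain state dict: strip the "base_model.model." prefix so keys match
# LlamaForCausalLM naming, and drop the now-redundant LoRA-specific tensors.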
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}
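
# Save as a standard HF checkpoint, sharded into ~400MB pieces.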
LlamaForCausalLM.save_pretrained(
    base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
)
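
# The exported checkpoint now loads like any stock HF model, e.g.:
#   model = LlamaForCausalLM.from_pretrained("./hf_ckpt", torch_dtype=torch.float16)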