2023-03-15 20:17:32 -04:00
import json
Add HF dataset loading, add linters, pyproject.toml (#175)
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* restore default settings
* resume_from_checkpoint
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
* Print warning on checkpoint not found
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* Default to local copy and update it
* Typo
* Remove duplicate code block
---------
Co-authored-by: Eric Wang <eric.james.wang@gmail.com>
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
2023-03-27 13:31:44 -04:00
import os
2023-03-15 20:17:32 -04:00
import torch
2023-03-16 15:08:13 -04:00
import transformers
Add HF dataset loading, add linters, pyproject.toml (#175)
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* restore default settings
* resume_from_checkpoint
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
* Print warning on checkpoint not found
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* Default to local copy and update it
* Typo
* Remove duplicate code block
---------
Co-authored-by: Eric Wang <eric.james.wang@gmail.com>
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
2023-03-27 13:31:44 -04:00
# Unused imports
# from peft import LoraConfig
from peft import PeftModel
2023-03-16 15:08:13 -04:00
# Fail fast when the installed transformers build does not register
# LlamaTokenizer under "models.llama"; the message tells the user how to
# reinstall transformers from the main branch.
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
from transformers import LlamaForCausalLM , LlamaTokenizer # noqa: E402
2023-03-15 20:17:32 -04:00
Add HF dataset loading, add linters, pyproject.toml (#175)
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* restore default settings
* resume_from_checkpoint
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
* Print warning on checkpoint not found
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* Default to local copy and update it
* Typo
* Remove duplicate code block
---------
Co-authored-by: Eric Wang <eric.james.wang@gmail.com>
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
2023-03-27 13:31:44 -04:00
# Identifier of the base LLaMA weights, read from the BASE_MODEL environment
# variable. It is passed to `from_pretrained` below, so the script refuses to
# continue when it is unset or empty.
BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert (
    BASE_MODEL
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
2023-03-23 16:54:39 -04:00
# Load the tokenizer for the base model (not referenced again in the visible
# part of this script).
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Load the base model in float16, without 8-bit quantisation, with every
# module mapped to the CPU.
base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)

# Wrap the base model with the "tloen/alpaca-lora-7b" PEFT adapter, using the
# same dtype and device placement as the base model.
lora_model = PeftModel.from_pretrained(
    base_model,
    "tloen/alpaca-lora-7b",
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)
2023-03-16 03:50:24 -04:00
# merge weights
# Flag the q_proj / v_proj modules of every decoder layer for merging.
# NOTE(review): presumably, with the peft version this script targets,
# leaving training mode below folds the LoRA deltas into the base weights --
# confirm against the installed peft release.
for layer in lora_model.base_model.model.model.layers:
    layer.self_attn.q_proj.merge_weights = True
    layer.self_attn.v_proj.merge_weights = True

lora_model.train(False)

# Snapshot of all parameters; keys are translated and filtered further below.
lora_model_sd = lora_model.state_dict()
# Model hyper-parameters; this dict is written verbatim to params.json at the
# end of the script.
params = {
    "dim": 4096,
    "multiple_of": 256,
    "n_heads": 32,
    "n_layers": 32,
    "norm_eps": 1e-06,
    "vocab_size": -1,
}
n_layers = params["n_layers"]
n_heads = params["n_heads"]
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
# One inverse frequency per pair of per-head dimensions:
# base ** -(2i / dims_per_head) for i in range(dims_per_head // 2).
# Not referenced again in the visible part of this script.
inv_freq = 1.0 / (
    base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
)
2023-03-15 20:17:32 -04:00
def permute(w, num_heads=None, model_dim=None):
    """Reorder the rows of a square projection matrix, head by head.

    Within each head's block of rows, the rows are viewed as
    ``(half, 2)`` pairs and the two axes are swapped, so rows
    ``0, 1, 2, 3, ...`` of a head become ``0, 2, 4, ..., 1, 3, 5, ...``.
    This is the inverse of :func:`unpermute`.

    NOTE(review): presumably this converts between the two rotary-embedding
    row layouts used for q/k projections -- confirm against the upstream
    conversion script. It is not called in the visible part of this file.

    Args:
        w: Tensor with ``model_dim * model_dim`` elements, viewable as
            ``(model_dim, model_dim)``.
        num_heads: Number of attention heads; defaults to the module-level
            ``n_heads``.
        model_dim: Model width; defaults to the module-level ``dim``.

    Returns:
        A ``(model_dim, model_dim)`` tensor with the rows reordered.
    """
    if num_heads is None:
        num_heads = n_heads
    if model_dim is None:
        model_dim = dim
    half_head_dim = model_dim // num_heads // 2
    return (
        w.view(num_heads, half_head_dim, 2, model_dim)
        .transpose(1, 2)
        .reshape(model_dim, model_dim)
    )
def unpermute(w, num_heads=None, model_dim=None):
    """Reorder the rows of a square projection matrix, head by head.

    Within each head's block of rows, the rows are viewed as
    ``(2, half)`` and the two axes are swapped, so the two halves of a head
    are interleaved: rows ``0, 1, 2, ...`` become
    ``0, half, 1, half + 1, ...``. This is the inverse of :func:`permute`.

    The export loop at the end of this script applies it to the ``wq`` and
    ``wk`` weights.
    NOTE(review): presumably this undoes the row reordering applied to q/k
    projections when the checkpoint was converted to the Hugging Face
    layout -- confirm against the upstream conversion script.

    Args:
        w: Tensor with ``model_dim * model_dim`` elements, viewable as
            ``(model_dim, model_dim)``.
        num_heads: Number of attention heads; defaults to the module-level
            ``n_heads``.
        model_dim: Model width; defaults to the module-level ``dim``.

    Returns:
        A ``(model_dim, model_dim)`` tensor with the rows reordered.
    """
    if num_heads is None:
        num_heads = n_heads
    if model_dim is None:
        model_dim = dim
    half_head_dim = model_dim // num_heads // 2
    return (
        w.view(num_heads, 2, half_head_dim, model_dim)
        .transpose(1, 2)
        .reshape(model_dim, model_dim)
    )
Add HF dataset loading, add linters, pyproject.toml (#175)
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* restore default settings
* resume_from_checkpoint
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
* Print warning on checkpoint not found
* add HF dataset loading, add linters, pyproject.toml
- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md
* Default to local copy and update it
* Typo
* Remove duplicate code block
---------
Co-authored-by: Eric Wang <eric.james.wang@gmail.com>
Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
2023-03-27 13:31:44 -04:00
def translate_state_dict_key(k):
    """Map a model state-dict key to its name in the exported checkpoint.

    Any ``"base_model.model."`` prefix is stripped first. The embedding,
    final norm and output head map to fixed names; per-layer weights map to
    ``layers.<n>.<target>`` where ``<n>`` is the layer index taken from the
    key.

    Args:
        k: Key from the model's ``state_dict()``.

    Returns:
        The translated key, or ``None`` for per-layer entries that must be
        left out of the export (``rotary_emb.inv_freq`` buffers and any key
        containing ``"lora"``).

    Raises:
        NotImplementedError: If the key matches no known pattern.
    """
    k = k.replace("base_model.model.", "")

    top_level = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    if k in top_level:
        return top_level[k]

    if not k.startswith("model.layers."):
        raise NotImplementedError(f"Unrecognized state-dict key: {k!r}")

    # Keys look like "model.layers.<n>.<...>", so the index is field 2.
    layer = k.split(".")[2]
    per_layer = (
        (".self_attn.q_proj.weight", "attention.wq.weight"),
        (".self_attn.k_proj.weight", "attention.wk.weight"),
        (".self_attn.v_proj.weight", "attention.wv.weight"),
        (".self_attn.o_proj.weight", "attention.wo.weight"),
        (".mlp.gate_proj.weight", "feed_forward.w1.weight"),
        (".mlp.down_proj.weight", "feed_forward.w2.weight"),
        (".mlp.up_proj.weight", "feed_forward.w3.weight"),
        (".input_layernorm.weight", "attention_norm.weight"),
        (".post_attention_layernorm.weight", "ffn_norm.weight"),
    )
    for suffix, target in per_layer:
        if k.endswith(suffix):
            return f"layers.{layer}.{target}"

    # Checked after the suffix table, matching the original branch order.
    if k.endswith("rotary_emb.inv_freq") or "lora" in k:
        return None

    raise NotImplementedError(
        f"Unrecognized key in layer {layer}: {k!r}"
    )
# Build the exported state dict: rename every key, drop the entries that
# translate to None, and reorder the rows of the wq / wk projections.
new_state_dict = {}
for k, v in lora_model_sd.items():
    new_k = translate_state_dict_key(k)
    if new_k is not None:
        if "wq" in new_k or "wk" in new_k:
            new_state_dict[new_k] = unpermute(v)
        else:
            new_state_dict[new_k] = v

# Write the weights and the hyper-parameter file side by side under ./ckpt.
os.makedirs("./ckpt", exist_ok=True)

torch.save(new_state_dict, "./ckpt/consolidated.00.pth")

with open("./ckpt/params.json", "w") as f:
    json.dump(params, f)