Update export_hf_checkpoint.py (#302)

* Update export_hf_checkpoint.py

* Update finetune.py

New tokenizer base model for the current dev branch of transformers

* Update generate.py

* Update export_state_dict_checkpoint.py

* Update export_hf_checkpoint.py
This commit is contained in:
Lily 2023-04-09 16:07:59 -05:00 committed by GitHub
parent 19ea31660c
commit 630d1146c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 6 additions and 8 deletions

View File

@ -8,7 +8,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
BASE_MODEL = os.environ.get("BASE_MODEL", None) BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert ( assert (
BASE_MODEL BASE_MODEL
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501 ), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
@ -35,10 +35,8 @@ lora_weight = lora_model.base_model.model.model.layers[
assert torch.allclose(first_weight_old, first_weight) assert torch.allclose(first_weight_old, first_weight)
# merge weights # merge weights - new merging method from peft
for layer in lora_model.base_model.model.model.layers: lora_model = lora_model.merge_and_unload()
layer.self_attn.q_proj.merge_weights = True
layer.self_attn.v_proj.merge_weights = True
lora_model.train(False) lora_model.train(False)

View File

@ -9,7 +9,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: E402
BASE_MODEL = os.environ.get("BASE_MODEL", None) BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert ( assert (
BASE_MODEL BASE_MODEL
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501 ), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

View File

@ -83,7 +83,7 @@ def train(
) )
assert ( assert (
base_model base_model
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'" ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name) prompter = Prompter(prompt_template_name)

View File

@ -34,7 +34,7 @@ def main(
base_model = base_model or os.environ.get("BASE_MODEL", "") base_model = base_model or os.environ.get("BASE_MODEL", "")
assert ( assert (
base_model base_model
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'" ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
prompter = Prompter(prompt_template) prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model) tokenizer = LlamaTokenizer.from_pretrained(base_model)