mirror of https://github.com/tloen/alpaca-lora.git
synced 2024-10-01 01:05:56 -04:00
Update export_hf_checkpoint.py (#302)
* Update export_hf_checkpoint.py
* Update finetune.py
  New tokenizer base model for the current dev branch of transformers
* Update generate.py
* Update export_state_dict_checkpoint.py
* Update export_hf_checkpoint.py
This commit is contained in:
parent 19ea31660c
commit 630d1146c8
--- a/export_hf_checkpoint.py
+++ b/export_hf_checkpoint.py
@@ -8,7 +8,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
 BASE_MODEL = os.environ.get("BASE_MODEL", None)
 assert (
     BASE_MODEL
-), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
+), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`"  # noqa: E501

 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

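Context for the swap: the decapoda-research conversion shipped a stale tokenizer_config.json (it names the tokenizer class `LLaMATokenizer`, a spelling current transformers no longer recognizes), which is presumably what the commit message means by a new tokenizer base model for the transformers dev branch. A quick sanity check of the replacement id (a sketch, assuming Hub access):

from transformers import LlamaTokenizer

# Loading the tokenizer alone is enough to surface a stale
# tokenizer_config.json; huggyllama/llama-7b loads cleanly on
# recent transformers where the old default did not.
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
print(tokenizer.__class__.__name__, tokenizer.eos_token_id)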
@@ -35,10 +35,8 @@ lora_weight = lora_model.base_model.model.model.layers[

 assert torch.allclose(first_weight_old, first_weight)

-# merge weights
-for layer in lora_model.base_model.model.model.layers:
-    layer.self_attn.q_proj.merge_weights = True
-    layer.self_attn.v_proj.merge_weights = True
+# merge weights - new merging method from peft
+lora_model = lora_model.merge_and_unload()

 lora_model.train(False)

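The functional change here: instead of flipping `merge_weights` flags on each attention projection by hand, the script now defers to peft's `merge_and_unload()`, which folds every LoRA delta (W + BA) into the base weights and strips the adapter wrappers. A minimal sketch of the new flow; the adapter id and output path follow this repo's defaults, but treat them as placeholders:

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM

# Load the base model, then attach the trained LoRA adapter on top.
base = LlamaForCausalLM.from_pretrained(
    "huggyllama/llama-7b", torch_dtype=torch.float16
)
lora_model = PeftModel.from_pretrained(base, "tloen/alpaca-lora-7b")

# merge_and_unload() bakes each adapter's low-rank update into the
# wrapped Linear layers and returns a plain transformers model.
lora_model = lora_model.merge_and_unload()
lora_model.save_pretrained("./hf_ckpt")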
--- a/export_state_dict_checkpoint.py
+++ b/export_state_dict_checkpoint.py
@@ -9,7 +9,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 BASE_MODEL = os.environ.get("BASE_MODEL", None)
 assert (
     BASE_MODEL
-), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
+), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`"  # noqa: E501

 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

--- a/finetune.py
+++ b/finetune.py
@@ -83,7 +83,7 @@ def train(
     )
     assert (
         base_model
-    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
     gradient_accumulation_steps = batch_size // micro_batch_size

     prompter = Prompter(prompt_template_name)
--- a/generate.py
+++ b/generate.py
@@ -34,7 +34,7 @@ def main(
     base_model = base_model or os.environ.get("BASE_MODEL", "")
     assert (
         base_model
-    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

     prompter = Prompter(prompt_template)
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
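One design note worth keeping when adapting the export script: it already guards the merge with `torch.allclose` checks (weights unchanged after loading the adapter, changed after merging). Roughly the same check can be reproduced standalone (a sketch; the layer path assumes the stock LLaMA module layout and the same placeholder ids as above):

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM

base = LlamaForCausalLM.from_pretrained(
    "huggyllama/llama-7b", torch_dtype=torch.float16
)
lora_model = PeftModel.from_pretrained(base, "tloen/alpaca-lora-7b")

# Before merging, .weight on a LoRA-wrapped Linear is still the frozen
# base weight, so a snapshot taken here is the pre-merge reference.
first_weight_old = (
    lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight.clone()
)

merged = lora_model.merge_and_unload()
first_weight = merged.model.layers[0].self_attn.q_proj.weight

# A trained adapter targeting q_proj should have changed this tensor.
assert not torch.allclose(first_weight_old, first_weight)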