mirror of
https://github.com/tloen/alpaca-lora.git
synced 2024-10-01 01:05:56 -04:00
Update export_hf_checkpoint.py (#302)
* Update export_hf_checkpoint.py * Update finetune.py New tokenizer base model for the current dev branch of transformers * Update generate.py * Update export_state_dict_checkpoint.py * Update export_hf_checkpoint.py
This commit is contained in:
parent
19ea31660c
commit
630d1146c8
@ -8,7 +8,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
|
||||
BASE_MODEL = os.environ.get("BASE_MODEL", None)
|
||||
assert (
|
||||
BASE_MODEL
|
||||
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501
|
||||
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
|
||||
|
||||
@ -35,10 +35,8 @@ lora_weight = lora_model.base_model.model.model.layers[
|
||||
|
||||
assert torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
# merge weights
|
||||
for layer in lora_model.base_model.model.model.layers:
|
||||
layer.self_attn.q_proj.merge_weights = True
|
||||
layer.self_attn.v_proj.merge_weights = True
|
||||
# merge weights - new merging method from peft
|
||||
lora_model = lora_model.merge_and_unload()
|
||||
|
||||
lora_model.train(False)
|
||||
|
||||
|
@ -9,7 +9,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: E402
|
||||
BASE_MODEL = os.environ.get("BASE_MODEL", None)
|
||||
assert (
|
||||
BASE_MODEL
|
||||
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501
|
||||
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
|
||||
|
||||
|
@ -83,7 +83,7 @@ def train(
|
||||
)
|
||||
assert (
|
||||
base_model
|
||||
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
||||
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
|
||||
gradient_accumulation_steps = batch_size // micro_batch_size
|
||||
|
||||
prompter = Prompter(prompt_template_name)
|
||||
|
@ -34,7 +34,7 @@ def main(
|
||||
base_model = base_model or os.environ.get("BASE_MODEL", "")
|
||||
assert (
|
||||
base_model
|
||||
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
||||
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
|
||||
|
||||
prompter = Prompter(prompt_template)
|
||||
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||
|
Loading…
Reference in New Issue
Block a user