Mirror of https://github.com/tloen/alpaca-lora.git (synced 2024-10-01 01:05:56 -04:00)
Templated prompter (#184)
* Templated prompter
* fix dup import
* Set Verbose False by default (I forgot to disable after testing.)
* Fix imports order
* Use Black formatting
* lint
* Re-introduce lost line
* Cleanup
* template default
* isort

Co-authored-by: Eric Wang <eric.james.wang@gmail.com>
commit 8d58d37b65
parent fcbc45e4c0
.gitignore (1 line added)

@@ -11,3 +11,4 @@ wandb
 evaluate.py
 test_data.json
 todo.txt
+.vscode/

finetune.py (45 lines changed)

@@ -13,14 +13,16 @@ import torch.nn as nn
 import bitsandbytes as bnb
 """
 
-from peft import (  # noqa: E402
+from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
     prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
-from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
+from utils.prompter import Prompter
+
 
 def train(
@@ -52,6 +54,7 @@ def train(
     wandb_watch: str = "",  # options: false | gradients | all
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
+    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -75,13 +78,16 @@ def train(
             f"wandb_run_name: {wandb_run_name}\n"
             f"wandb_watch: {wandb_watch}\n"
             f"wandb_log_model: {wandb_log_model}\n"
-            f"resume_from_checkpoint: {resume_from_checkpoint}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
+            f"prompt template: {prompt_template_name}\n"
         )
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
     gradient_accumulation_steps = batch_size // micro_batch_size
 
+    prompter = Prompter(prompt_template_name)
+
     device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
@@ -138,10 +144,16 @@ def train(
         return result
 
     def generate_and_tokenize_prompt(data_point):
-        full_prompt = generate_prompt(data_point)
+        full_prompt = prompter.generate_prompt(
+            data_point["instruction"],
+            data_point["input"],
+            data_point["output"],
+        )
         tokenized_full_prompt = tokenize(full_prompt)
         if not train_on_inputs:
-            user_prompt = generate_prompt({**data_point, "output": ""})
+            user_prompt = prompter.generate_prompt(
+                data_point["instruction"], data_point["input"]
+            )
             tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
 
@@ -260,28 +272,5 @@ def train(
     )
 
 
-def generate_prompt(data_point):
-    # sorry about the formatting disaster gotta move fast
-    if data_point["input"]:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{data_point["instruction"]}
-
-### Input:
-{data_point["input"]}
-
-### Response:
-{data_point["output"]}"""
-    else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{data_point["instruction"]}
-
-### Response:
-{data_point["output"]}"""
-
-
 if __name__ == "__main__":
     fire.Fire(train)

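The second-to-last hunk ends right after `user_prompt_len` is computed. For context, the sketch below illustrates what that length is typically used for when `train_on_inputs` is False: masking the prompt tokens out of the loss so only the response tokens are trained on. The helper name `mask_prompt_labels` is illustrative rather than taken from the repository, and it assumes the Hugging Face convention that a label of -100 is ignored by the loss.

```python
# Illustrative sketch, not part of the commit: mask the prompt portion of the
# labels so the loss is computed only on the response tokens.
def mask_prompt_labels(tokenized_full_prompt: dict, user_prompt_len: int) -> dict:
    labels = tokenized_full_prompt["labels"]
    # -100 is the label index ignored by Hugging Face's cross-entropy loss.
    tokenized_full_prompt["labels"] = [-100] * user_prompt_len + labels[user_prompt_len:]
    return tokenized_full_prompt
```
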
generate.py (33 lines changed)

@@ -7,6 +7,8 @@ import transformers
 from peft import PeftModel
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
+from utils.prompter import Prompter
+
 if torch.cuda.is_available():
     device = "cuda"
 else:
@@ -23,12 +25,15 @@ def main(
     load_8bit: bool = False,
     base_model: str = "",
     lora_weights: str = "tloen/alpaca-lora-7b",
+    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    server_name: str = "127.0.0.1",  # Allows to listen on all interfaces by providing '0.0.0.0'
     share_gradio: bool = False,
 ):
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
 
+    prompter = Prompter(prompt_template)
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
         model = LlamaForCausalLM.from_pretrained(
@@ -86,7 +91,7 @@ def main(
         max_new_tokens=128,
         **kwargs,
     ):
-        prompt = generate_prompt(instruction, input)
+        prompt = prompter.generate_prompt(instruction, input)
         inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
@@ -106,7 +111,7 @@ def main(
             )
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
-        return output.split("### Response:")[1].strip()
+        return prompter.get_response(output)
 
     gr.Interface(
         fn=evaluate,
@@ -141,7 +146,7 @@ def main(
         ],
         title="🦙🌲 Alpaca-LoRA",
         description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",  # noqa: E501
-    ).launch(share=share_gradio)
+    ).launch(server_name=server_name, share=share_gradio)
     # Old testing code follows.
 
     """
@@ -163,27 +168,5 @@ def main(
     """
 
 
-def generate_prompt(instruction, input=None):
-    if input:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{instruction}
-
-### Input:
-{input}
-
-### Response:
-"""
-    else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{instruction}
-
-### Response:
-"""
-
-
 if __name__ == "__main__":
     fire.Fire(main)

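The old code cut the reply out of the decoded output by splitting on a hard-coded "### Response:" string. `Prompter.get_response` performs the same split, but takes the separator from the template's `response_split` field, so non-default templates (for example vigogne's "### Réponse:") keep working. A small sketch follows, assuming the script is run from the repository root so `templates/alpaca.json` can be found; the decoded text is made up for illustration.

```python
from utils.prompter import Prompter

prompter = Prompter("")  # an empty name falls back to the default "alpaca" template

decoded = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nSay hi.\n\n### Response:\nHello!"
)

# Equivalent to the removed output.split("### Response:")[1].strip(),
# but driven by the template's response_split value.
print(prompter.get_response(decoded))  # -> "Hello!"
```
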
templates/README.md (new file, 46 lines)

# Prompt templates

This directory contains template styles for the prompts used to finetune LoRA models.

## Format

A template is described via a JSON file with the following keys:

- `prompt_input`: The template to use when input is not None. Uses the `{instruction}` and `{input}` placeholders.
- `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
- `description`: A short description of the template, with possible use cases.
- `response_split`: The text to use as a separator when cutting the real response out of the model output.

No `{response}` placeholder is used, since the response is always the last element of the template and is simply concatenated to the rest.

## Example template

The default template, used unless otherwise specified, is `alpaca.json`:

```json
{
    "description": "Template used by Alpaca-LoRA.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"
}
```

## Current templates

### alpaca

Default template used for generic LoRA fine-tunes so far.

### alpaca_legacy

Legacy template used by the original Alpaca repo, with no `\n` after the response field. Kept for reference and experiments.

### alpaca_short

A trimmed-down alpaca template which seems to perform just as well and saves some tokens. Models created with the default template also appear to be queryable with the short template. More experiments are welcome.

### vigogne

The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and should be used to query it, or for extra fine-tuning.

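To make the key semantics concrete, here is a minimal sketch of how a template file turns into a prompt string. It mirrors what `utils/prompter.py` (added below) does with `str.format`; the instruction and input values are invented for illustration, and the path assumes the repository root as the working directory.

```python
import json

with open("templates/alpaca.json") as fp:
    template = json.load(fp)

# With an input, use prompt_input; without one, use prompt_no_input instead.
prompt = template["prompt_input"].format(
    instruction="Translate the sentence to German.",
    input="Good morning!",
)

print(prompt)                      # ends with "### Response:\n"; the response is appended after it
print(template["response_split"])  # "### Response:" — used later to cut the reply out of the model output
```
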
templates/alpaca.json (new file, 6 lines)

{
    "description": "Template used by Alpaca-LoRA.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"
}

templates/alpaca_legacy.json (new file, 6 lines)

{
    "description": "Legacy template, used by Original Alpaca repository.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
    "response_split": "### Response:"
}

templates/alpaca_short.json (new file, 6 lines)

{
    "description": "A shorter template to experiment with.",
    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"
}

templates/vigogne.json (new file, 6 lines)

{
    "description": "French template, used by Vigogne for finetuning.",
    "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
    "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
    "response_split": "### Réponse:"
}

utils/README.md (new file, 7 lines)

# Directory for helper modules

## prompter.py

Prompter class, a template manager.

`from utils.prompter import Prompter`

utils/__init__.py (new, empty file)

utils/prompter.py (new file, 51 lines)

"""
A dedicated helper to manage templates and prompt building.
"""

import json
import os.path as osp
from typing import Union


class Prompter(object):
    __slots__ = ("template", "_verbose")

    def __init__(self, template_name: str = "", verbose: bool = False):
        self._verbose = verbose
        if not template_name:
            # Enforce the default here, so the constructor can be called with '' and will not break.
            template_name = "alpaca"
        file_name = osp.join("templates", f"{template_name}.json")
        if not osp.exists(file_name):
            raise ValueError(f"Can't read {file_name}")
        with open(file_name) as fp:
            self.template = json.load(fp)
        if self._verbose:
            print(
                f"Using prompt template {template_name}: {self.template['description']}"
            )

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        # returns the full prompt from instruction and optional input
        # if a label (=response, =output) is provided, it's also appended.
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        if self._verbose:
            print(res)
        return res

    def get_response(self, output: str) -> str:
        return output.split(self.template["response_split"])[1].strip()

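For reference, a brief usage sketch of the class as the two scripts use it (run from the repository root so the `templates/` lookup succeeds); the example strings are invented.

```python
from utils.prompter import Prompter

prompter = Prompter("alpaca")

# Training-time use (finetune.py): the expected output is passed as the label,
# so it is appended right after "### Response:\n" and tokenized with the prompt.
full_prompt = prompter.generate_prompt(
    "Name the capital of France.", None, "Paris."
)

# Inference-time use (generate.py): no label; the model completes the prompt and
# get_response() returns everything after the template's response_split marker.
user_prompt = prompter.generate_prompt("Name the capital of France.")
assert full_prompt.startswith(user_prompt)
```
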