Templated prompter (#184)

* Templated prompter

* fix dup import

* Set Verbose False by default

I forgot to disable after testing.

* Fix imports order

* Use Black Formatting

* lint

* Re-introduce lost line

* Cleanup

* template default

* isort

---------

Co-authored-by: Eric Wang <eric.james.wang@gmail.com>
Authored by Angainor Development on 2023-03-30 01:36:04 +02:00, committed by GitHub
parent fcbc45e4c0
commit 8d58d37b65
11 changed files with 155 additions and 54 deletions

.gitignore

```diff
@@ -10,4 +10,5 @@ lora-**
 wandb
 evaluate.py
 test_data.json
 todo.txt
+.vscode/
```

finetune.py

```diff
@@ -13,14 +13,16 @@ import torch.nn as nn
 import bitsandbytes as bnb
 """
-from peft import (  # noqa: E402
+from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
     prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
-from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+from utils.prompter import Prompter
@@ -52,6 +54,7 @@ def train(
     wandb_watch: str = "",  # options: false | gradients | all
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
+    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -75,13 +78,16 @@ def train(
             f"wandb_run_name: {wandb_run_name}\n"
             f"wandb_watch: {wandb_watch}\n"
             f"wandb_log_model: {wandb_log_model}\n"
-            f"resume_from_checkpoint: {resume_from_checkpoint}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
+            f"prompt template: {prompt_template_name}\n"
         )
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
     gradient_accumulation_steps = batch_size // micro_batch_size

+    prompter = Prompter(prompt_template_name)
+
     device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
@@ -138,10 +144,16 @@ def train(
         return result

     def generate_and_tokenize_prompt(data_point):
-        full_prompt = generate_prompt(data_point)
+        full_prompt = prompter.generate_prompt(
+            data_point["instruction"],
+            data_point["input"],
+            data_point["output"],
+        )
         tokenized_full_prompt = tokenize(full_prompt)
         if not train_on_inputs:
-            user_prompt = generate_prompt({**data_point, "output": ""})
+            user_prompt = prompter.generate_prompt(
+                data_point["instruction"], data_point["input"]
+            )
             tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -260,28 +272,5 @@ def train(
     )

-def generate_prompt(data_point):
-    # sorry about the formatting disaster gotta move fast
-    if data_point["input"]:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{data_point["instruction"]}
-
-### Input:
-{data_point["input"]}
-
-### Response:
-{data_point["output"]}"""
-    else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{data_point["instruction"]}
-
-### Response:
-{data_point["output"]}"""
-
 if __name__ == "__main__":
     fire.Fire(train)
```

generate.py

```diff
@@ -7,6 +7,8 @@ import transformers
 from peft import PeftModel
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

+from utils.prompter import Prompter
+
 if torch.cuda.is_available():
     device = "cuda"
 else:
@@ -23,12 +25,15 @@ def main(
     load_8bit: bool = False,
     base_model: str = "",
     lora_weights: str = "tloen/alpaca-lora-7b",
+    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    server_name: str = "127.0.0.1",  # Allows to listen on all interfaces by providing '0.0.0.0'
     share_gradio: bool = False,
 ):
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+    prompter = Prompter(prompt_template)
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
         model = LlamaForCausalLM.from_pretrained(
@@ -86,7 +91,7 @@ def main(
         max_new_tokens=128,
         **kwargs,
     ):
-        prompt = generate_prompt(instruction, input)
+        prompt = prompter.generate_prompt(instruction, input)
         inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
@@ -106,7 +111,7 @@ def main(
         )
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
-        return output.split("### Response:")[1].strip()
+        return prompter.get_response(output)

     gr.Interface(
         fn=evaluate,
@@ -141,7 +146,7 @@ def main(
         ],
         title="🦙🌲 Alpaca-LoRA",
         description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",  # noqa: E501
-    ).launch(share=share_gradio)
+    ).launch(server_name=server_name, share=share_gradio)

     # Old testing code follows.
     """
@@ -163,27 +168,5 @@ def main(
     """

-def generate_prompt(instruction, input=None):
-    if input:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{instruction}
-
-### Input:
-{input}
-
-### Response:
-"""
-    else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{instruction}
-
-### Response:
-"""
-
 if __name__ == "__main__":
     fire.Fire(main)
```

templates/README.md (new file)
# Prompt templates
This directory contains template styles for the prompts used to finetune LoRA models.
## Format
A template is described via a JSON file with the following keys:
- `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
- `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
- `description`: A short description of the template, with possible use cases.
- `response_split`: The text to use as a separator when extracting the real response from the model output.
There is no `{response}` placeholder, since the response is always the last element of the template and is simply concatenated to the rest.
## Example template
The default template, used unless otherwise specified, is `alpaca.json`:
```json
{
"description": "Template used by Alpaca-LoRA.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}
```
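
As an illustration of how these keys are consumed, here is a minimal sketch of rendering a template with Python's built-in `str.format` (this mirrors what `utils/prompter.py` below does; the instruction and input strings are made-up examples):

```python
import json

# Load a template file (alpaca.json, shown above).
with open("templates/alpaca.json") as fp:
    template = json.load(fp)

# Fill the placeholders with a (made-up) instruction/input pair.
prompt = template["prompt_input"].format(
    instruction="Translate this sentence to English.",
    input="Bonjour le monde.",
)

# The prompt ends right after "### Response:\n", ready for the model
# to complete; `response_split` later cuts the answer back out.
print(prompt)
```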
## Current templates
### alpaca
The default template, used for generic LoRA fine-tunes so far.
### alpaca_legacy
Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
### alpaca_short
A trimmed-down alpaca template which seems to perform just as well and saves some tokens. Models created with the default template seem to be queryable with the short template as well. More experiments are welcome.
### vigogne
The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and should be used to query it, or for further fine-tuning.

templates/alpaca.json (new file)

```json
{
    "description": "Template used by Alpaca-LoRA.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"
}
```

templates/alpaca_legacy.json (new file)

```json
{
    "description": "Legacy template, used by Original Alpaca repository.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
    "response_split": "### Response:"
}
```

templates/alpaca_short.json (new file)

```json
{
    "description": "A shorter template to experiment with.",
    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"
}
```

templates/vigogne.json (new file)

```json
{
    "description": "French template, used by Vigogne for finetuning.",
    "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
    "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
    "response_split": "### Réponse:"
}
```

utils/README.md (new file)

# Directory for helper modules
## prompter.py
Prompter class, a template manager.
`from utils.prompter import Prompter`
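
As a quick usage illustration (a minimal sketch based on the `Prompter` source below; the instruction, input, and output strings are made-up examples):

```python
from utils.prompter import Prompter

# An empty template name falls back to the default "alpaca" template.
prompter = Prompter()

# Full training prompt: instruction + input + expected output (label).
full_prompt = prompter.generate_prompt("Sum the numbers.", "2 and 3", "5")

# Inference-style prompt: no label, so it ends right after "### Response:\n".
query_prompt = prompter.generate_prompt("Sum the numbers.", "2 and 3")

# get_response() cuts everything up to and including the response marker,
# so on the full prompt it recovers just the label.
assert prompter.get_response(full_prompt) == "5"
```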

utils/__init__.py (new, empty file)

utils/prompter.py (new file)

```python
"""
A dedicated helper to manage templates and prompt building.
"""

import json
import os.path as osp
from typing import Union


class Prompter(object):
    __slots__ = ("template", "_verbose")

    def __init__(self, template_name: str = "", verbose: bool = False):
        self._verbose = verbose
        if not template_name:
            # Enforce the default here, so the constructor can be called with '' and will not break.
            template_name = "alpaca"
        file_name = osp.join("templates", f"{template_name}.json")
        if not osp.exists(file_name):
            raise ValueError(f"Can't read {file_name}")
        with open(file_name) as fp:
            self.template = json.load(fp)
        if self._verbose:
            print(
                f"Using prompt template {template_name}: {self.template['description']}"
            )

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        # Returns the full prompt from instruction and optional input;
        # if a label (=response, =output) is provided, it's also appended.
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        if self._verbose:
            print(res)
        return res

    def get_response(self, output: str) -> str:
        return output.split(self.template["response_split"])[1].strip()
```
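
Adding a new prompt style should then just be a matter of dropping another JSON file with the four documented keys into `templates/` and passing its name. A minimal sketch (the `my_style` name and its wording are made-up examples, not part of this commit):

```python
import json

from utils.prompter import Prompter

# A hypothetical template following the format documented in templates/README.md.
custom = {
    "description": "Made-up example template.",
    "prompt_input": "Task: {instruction}\nContext: {input}\nAnswer:\n",
    "prompt_no_input": "Task: {instruction}\nAnswer:\n",
    "response_split": "Answer:",
}

with open("templates/my_style.json", "w") as fp:
    json.dump(custom, fp, indent=4)

# Prompter looks templates up by name under templates/.
prompter = Prompter("my_style", verbose=True)
print(prompter.generate_prompt("Say hello."))
# Equivalent CLI usage: generate.py --prompt_template 'my_style'
```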