mirror of https://github.com/tatsu-lab/stanford_alpaca.git
synced 2024-10-01 05:35:37 -04:00

training code.

This commit is contained in:
parent 7ad0c6b4f7
commit fa78c4fb25

README.md (65)
@@ -91,7 +91,7 @@ The inner circle of the plot represents the root verb of the instructions, and t

[<img src="assets/parse_analysis.png" width="750" />](./assets/parse_analysis.png)

## Fine-tuning

-We fine-tune our model using standard Hugging Face training code with the following hyperparameters:
+We fine-tune our models using standard Hugging Face training code with the following hyperparameters:

| Hyperparameter | Value |
|----------------|-------|
@@ -101,7 +101,68 @@ We fine-tune our model using standard Hugging Face training code with the follow
| Max length     | 512   |
| Weight decay   | 1     |

-We are waiting for Hugging Face to officially support the llama models (i.e. this [PR](https://github.com/huggingface/transformers/pull/21955) to be merged) before we release a stable version of the finetuning code.
+Given that Hugging Face hasn't yet officially supported the LLaMA models, we fine-tuned LLaMA with Hugging Face's transformers library by installing it from a particular fork (i.e. this [PR](https://github.com/huggingface/transformers/pull/21955), which is yet to be merged).
The hash of the specific commit we installed was `68d640f7c368bcaaaecfc678f11908ebbd3d6176`.

To reproduce our fine-tuning runs for LLaMA, first install the requirements:

```bash
pip install -r requirements.txt
```

Then, install the particular fork of Hugging Face's transformers library.
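
One way to do this (a minimal sketch; the fork URL below is a placeholder, substitute the repository referenced in the PR above) is to install transformers directly at the pinned commit:

```bash
# Illustrative only: <transformers_fork_url> stands in for the fork behind the PR above.
git clone <transformers_fork_url> transformers-llama
cd transformers-llama
git checkout 68d640f7c368bcaaaecfc678f11908ebbd3d6176
pip install -e .
```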

Below is a command that fine-tunes LLaMA-7B with our dataset on a machine with 4 A100 80G GPUs in FSDP `full_shard` mode.
Replace `<your_random_port>` with a port of your own, `<your_path_to_hf_converted_llama_ckpt_and_tokenizer>` with the
path to your converted checkpoint and tokenizer (following instructions in the PR), and `<your_output_dir>` with where you want to store your outputs.

```bash
torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
    --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
    --data_path ./alpaca_data.json \
    --bf16 True \
    --output_dir <your_output_dir> \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
    --tf32 True
```

The same script also works for OPT fine-tuning. Here's an example for fine-tuning OPT-6.7B:

```bash
torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
    --model_name_or_path "facebook/opt-6.7b" \
    --data_path ./alpaca_data.json \
    --bf16 True \
    --output_dir <your_output_dir> \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \
    --tf32 True
```

### Authors

All grad students below contributed equally and the order is determined by random draw.
requirements.txt

@@ -1,4 +1,9 @@
numpy
rouge_score
fire
openai
transformers>=4.26.1
torch
sentencepiece
tokenizers==0.12.1
wandb

train.py (232, new file)
@@ -0,0 +1,232 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

import utils

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
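
# For illustration (hypothetical record, not taken from the dataset itself): a record like
#   {"instruction": "Give three tips for staying healthy.", "input": "", "output": "..."}
# has an empty "input", so it is formatted with "prompt_no_input", yielding
#   "Below is an instruction that describes a task. Write a response that appropriately
#    completes the request.\n\n### Instruction:\nGive three tips for staying healthy.\n\n### Response:"
# and the model is trained to produce the record's "output" text after "### Response:".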


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the state dict and dump it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
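    # The freshly added embedding rows (both input and output) are initialized below
    # to the mean of the pre-existing embeddings, rather than being left at their
    # default random initialization.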

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
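    # Labels start as a copy of the full (prompt + response) token ids; the prompt
    # portion is then masked with IGNORE_INDEX so the loss is computed only on the
    # response tokens.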
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
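        # Dynamically pad each batch to its longest example: input_ids are padded with
        # the pad token id, while labels are padded with IGNORE_INDEX so the padded
        # positions never contribute to the loss.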
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
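    # Because TrainingArguments subclasses transformers.TrainingArguments, every standard
    # Hugging Face training flag used in the README commands (e.g. --bf16, --fsdp,
    # --learning_rate) is parsed here alongside the custom fields defined above.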

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
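    # For LLaMA checkpoints the eos/bos/unk strings are registered explicitly below,
    # presumably because the tokenizer converted via the fork above does not define
    # them out of the box.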
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()