Merge pull request #216 from tatsu-lab/hf-migrate

let training code run with huggingface transformers main
This commit is contained in:
Xuechen Li 2023-04-15 15:50:02 -07:00 committed by GitHub
commit e408b27bfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 126 additions and 45 deletions

View File

@ -15,6 +15,7 @@ This is the repo for the Stanford Alpaca project, which aims to build and share
- The [52K data](#data-release) used for fine-tuning the model.
- The code for [generating the data](#data-generation-process).
- The code for [fine-tuning the model](#fine-tuning).
- The code for [recovering Alpaca-7B weights from our released weight diff](#recovering-alpaca-weights).
Note: We thank the community for feedback on Stanford-Alpaca and supporting our research. Our live demo is suspended until further notice.
@ -115,10 +116,7 @@ We fine-tune LLaMA-7B and LLaMA-13B with the following hyperparameters:
| Max length | 512 | 512 |
| Weight decay | 0 | 0 |
We have also fine-tuned larger variants of LLaMA and are in the process of evaluating those models.
Given Hugging Face hasn't officially supported the LLaMA models, we fine-tuned LLaMA with Hugging Face's transformers library by installing it from a particular fork (i.e. this [PR](https://github.com/huggingface/transformers/pull/21955) to be merged).
The hash of the specific commit we installed was `68d640f7c368bcaaaecfc678f11908ebbd3d6176`.
We have also fine-tuned larger variants of LLaMA and performed subsequent RLHF and are in the process of evaluating those models.
To reproduce our fine-tuning runs for LLaMA, first install the requirements
@ -153,20 +151,10 @@ torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True
```
### Warning
`fsdp_transformer_layer_cls_to_wrap` must be set to the name of the specific decoder layer.
The LLaMA Hugging Face PR is not stable.
Earlier commits used the name `LLaMADecoderLayer` for their decoder layer (the commit hash our code is based on this).
More recent commits use `LlamaDecoderLayer` (notice the small case difference).
Not setting `fsdp_transformer_layer_cls_to_wrap` to the correct name will lead to drastic slowdowns in training.
### Side notes
The same script also works for OPT fine-tuning. Here's an example for fine-tuning OPT-6.7B
```bash
@ -196,6 +184,58 @@ torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
Note the given training script is meant to be simple and easy to use, and is not particularly optimized.
To run on more gpus, you may prefer to turn down `gradient_accumulation_steps` to keep a global batch size of 128. Global batch size has not been tested for optimality.
### Addressing OOM
Naively, fine-tuning a 7B model requires about 7 x 4 x 4 = 112 GB of VRAM. Commands given above enable parameter sharding, so no redundant model copy is stored on any GPU.
If you'd like to further reduce the memory footprint, here are some options:
- Turn on CPU offload for FSDP with `--fsdp "full_shard auto_wrap offload"`. This saves VRAM at the cost longer runtime.
- In our experience, DeepSpeed stage-3 (with offload) can at times be more memory efficient than FSDP. Here's an example to use DeepSpeed stage-3 with 4 GPUs with both parameter and optimizer offload:
```bash
pip install deepspeed
torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
--model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
--data_path ./alpaca_data.json \
--bf16 True \
--output_dir <your_output_dir> \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--deepspeed "./configs/default_offload_opt_param.json" \
--tf32 True
```
- The DeepSpeed library also provides some [helpful functions](https://deepspeed.readthedocs.io/en/latest/memory.html) to estimate memory usage.
- [LoRA](https://arxiv.org/abs/2106.09685) fine-tunes low-rank slices of the query, key, and value embeddings. This can reduce the total memory footprint from 112GB to about 7x4=28GB. We may release our re-implemention of this in the future, but for now the [peft](https://github.com/huggingface/peft) codebase can be a useful resource.
## Recovering Alpaca Weights
The weight diff between Alpaca-7B and LLaMA-7B is located [here](https://huggingface.co/tatsu-lab/alpaca-7b-wdiff/tree/main).
To recover the original Alpaca-7B weights, follow these steps:
```text
1. Convert Meta's released weights into huggingface format. Follow this guide:
https://huggingface.co/docs/transformers/main/model_doc/llama
2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
3. Run this function with the correct paths. E.g.,
python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir> --path_tuned <path_to_store_recovered_weights>
```
Once step 3 completes, you should have a directory with the recovered weights, from which you can load the model like the following
```python
import transformers
alpaca_model = transformers.AutoModelForCausalLM.from_pretrained("<path_to_store_recovered_weights>")
alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained("<path_to_store_recovered_weights>")
```
### Authors
All grad students below contributed equally and the order is determined by random draw.

View File

@ -0,0 +1,49 @@
{
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"total_num_steps": "auto",
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": false
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 5,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@ -2,8 +2,8 @@ numpy
rouge_score
fire
openai
transformers>=4.26.1
transformers>=4.28.1
torch
sentencepiece
tokenizers==0.12.1
tokenizers>=0.13.3
wandb

View File

@ -15,20 +15,19 @@
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
from typing import Dict, Optional, Sequence
import torch
import transformers
import utils
from torch.utils.data import Dataset
from transformers import Trainer
import utils
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
@ -63,15 +62,6 @@ class TrainingArguments(transformers.TrainingArguments):
)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
@ -205,26 +195,27 @@ def train():
padding_side="right",
use_fast=False,
)
special_tokens_dict = dict()
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
if "llama" in model_args.model_name_or_path:
tokenizer.add_special_tokens(
{
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
}
)
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
if tokenizer.eos_token is None:
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
if tokenizer.bos_token is None:
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
if tokenizer.unk_token is None:
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
smart_tokenizer_and_embedding_resize(
special_tokens_dict=special_tokens_dict,
tokenizer=tokenizer,
model=model,
)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
trainer.train()
trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
trainer.save_model(output_dir=training_args.output_dir)
if __name__ == "__main__":

View File

@ -133,6 +133,7 @@ def recover(
if path_tuned is not None:
model_recovered.save_pretrained(path_tuned)
tokenizer_recovered.save_pretrained(path_tuned)
if test_inference:
input_text = (