mirror of
https://github.com/tatsu-lab/stanford_alpaca.git
synced 2024-10-01 05:35:37 -04:00
migrate to latest main for hf transformers.
This commit is contained in:
parent
7a95b21d2c
commit
3783d185b5
70
README.md
70
README.md
@ -15,6 +15,7 @@ This is the repo for the Stanford Alpaca project, which aims to build and share
|
||||
- The [52K data](#data-release) used for fine-tuning the model.
|
||||
- The code for [generating the data](#data-generation-process).
|
||||
- The code for [fine-tuning the model](#fine-tuning).
|
||||
- The code for [recovering Alpaca-7B weights from our released weight diff](#recovering-alpaca-weights).
|
||||
|
||||
Note: We thank the community for feedback on Stanford-Alpaca and supporting our research. Our live demo is suspended until further notice.
|
||||
|
||||
@ -115,10 +116,7 @@ We fine-tune LLaMA-7B and LLaMA-13B with the following hyperparameters:
|
||||
| Max length | 512 | 512 |
|
||||
| Weight decay | 0 | 0 |
|
||||
|
||||
We have also fine-tuned larger variants of LLaMA and are in the process of evaluating those models.
|
||||
|
||||
Given Hugging Face hasn't officially supported the LLaMA models, we fine-tuned LLaMA with Hugging Face's transformers library by installing it from a particular fork (i.e. this [PR](https://github.com/huggingface/transformers/pull/21955) to be merged).
|
||||
The hash of the specific commit we installed was `68d640f7c368bcaaaecfc678f11908ebbd3d6176`.
|
||||
We have also fine-tuned larger variants of LLaMA and performed subsequent RLHF and are in the process of evaluating those models.
|
||||
|
||||
To reproduce our fine-tuning runs for LLaMA, first install the requirements
|
||||
|
||||
@ -153,20 +151,10 @@ torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
|
||||
--lr_scheduler_type "cosine" \
|
||||
--logging_steps 1 \
|
||||
--fsdp "full_shard auto_wrap" \
|
||||
--fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
|
||||
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
|
||||
--tf32 True
|
||||
```
|
||||
|
||||
### Warning
|
||||
|
||||
`fsdp_transformer_layer_cls_to_wrap` must be set to the name of the specific decoder layer.
|
||||
The LLaMA Hugging Face PR is not stable.
|
||||
Earlier commits used the name `LLaMADecoderLayer` for their decoder layer (the commit hash our code is based on this).
|
||||
More recent commits use `LlamaDecoderLayer` (notice the small case difference).
|
||||
Not setting `fsdp_transformer_layer_cls_to_wrap` to the correct name will lead to drastic slowdowns in training.
|
||||
|
||||
### Side notes
|
||||
|
||||
The same script also works for OPT fine-tuning. Here's an example for fine-tuning OPT-6.7B
|
||||
|
||||
```bash
|
||||
@ -196,6 +184,58 @@ torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
|
||||
Note the given training script is meant to be simple and easy to use, and is not particularly optimized.
|
||||
To run on more gpus, you may prefer to turn down `gradient_accumulation_steps` to keep a global batch size of 128. Global batch size has not been tested for optimality.
|
||||
|
||||
### Addressing OOM
|
||||
|
||||
Naively, fine-tuning a 7B model requires about 7 x 4 x 4 = 112 GB of VRAM. Commands given above enable parameter sharding, so no redundant model copy is stored on any GPU.
|
||||
If you'd like to further reduce the memory footprint, here are some options:
|
||||
|
||||
- Turn on CPU offload for FSDP with `--fsdp "full_shard auto_wrap offload"`. This saves VRAM at the cost longer runtime.
|
||||
- In our experience, DeepSpeed stage-3 (with offload) can at times be more memory efficient than FSDP. Here's an example to use DeepSpeed stage-3 with 4 GPUs with both parameter and optimizer offload:
|
||||
```bash
|
||||
pip install deepspeed
|
||||
torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
|
||||
--model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
|
||||
--data_path ./alpaca_data.json \
|
||||
--bf16 True \
|
||||
--output_dir <your_output_dir> \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--per_device_eval_batch_size 4 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--evaluation_strategy "no" \
|
||||
--save_strategy "steps" \
|
||||
--save_steps 2000 \
|
||||
--save_total_limit 1 \
|
||||
--learning_rate 2e-5 \
|
||||
--weight_decay 0. \
|
||||
--warmup_ratio 0.03 \
|
||||
--deepspeed "./configs/default_offload_opt_param.json" \
|
||||
--tf32 True
|
||||
```
|
||||
- The DeepSpeed library also provides some [helpful functions](https://deepspeed.readthedocs.io/en/latest/memory.html) to estimate memory usage.
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685) fine-tunes low-rank slices of the query, key, and value embeddings. This can reduce the total memory footprint from 112GB to about 7x4=28GB. We may release our re-implemention of this in the future, but for now the [peft](https://github.com/huggingface/peft) codebase can be a useful resource.
|
||||
|
||||
## Recovering Alpaca Weights
|
||||
|
||||
The weight diff between Alpaca-7B and LLaMA-7B is located [here](https://huggingface.co/tatsu-lab/alpaca-7b-wdiff/tree/main).
|
||||
To recover the original Alpaca-7B weights, follow these steps:
|
||||
```text
|
||||
1. Convert Meta's released weights into huggingface format. Follow this guide:
|
||||
https://huggingface.co/docs/transformers/main/model_doc/llama
|
||||
2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
|
||||
https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
|
||||
3. Run this function with the correct paths. E.g.,
|
||||
python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir> --path_tuned <path_to_store_recovered_weights>
|
||||
```
|
||||
|
||||
Once step 3 completes, you should have a directory with the recovered weights, from which you can load the model like the following
|
||||
|
||||
```python
|
||||
import transformers
|
||||
alpaca_model = transformers.AutoModelForCausalLM.from_pretrained("<path_to_store_recovered_weights>")
|
||||
alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained("<path_to_store_recovered_weights>")
|
||||
```
|
||||
|
||||
### Authors
|
||||
|
||||
All grad students below contributed equally and the order is determined by random draw.
|
||||
|
49
configs/default_offload_opt_param.json
Normal file
49
configs/default_offload_opt_param.json
Normal file
@ -0,0 +1,49 @@
|
||||
{
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"total_num_steps": "auto",
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": false
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 5,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
@ -2,8 +2,8 @@ numpy
|
||||
rouge_score
|
||||
fire
|
||||
openai
|
||||
transformers>=4.26.1
|
||||
transformers>=4.28.1
|
||||
torch
|
||||
sentencepiece
|
||||
tokenizers==0.12.1
|
||||
tokenizers>=0.13.3
|
||||
wandb
|
||||
|
39
train.py
39
train.py
@ -15,20 +15,19 @@
|
||||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Sequence
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import utils
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import Trainer
|
||||
|
||||
import utils
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
DEFAULT_PAD_TOKEN = "[PAD]"
|
||||
DEFAULT_EOS_TOKEN = "</s>"
|
||||
DEFAULT_BOS_TOKEN = "</s>"
|
||||
DEFAULT_UNK_TOKEN = "</s>"
|
||||
DEFAULT_BOS_TOKEN = "<s>"
|
||||
DEFAULT_UNK_TOKEN = "<unk>"
|
||||
PROMPT_DICT = {
|
||||
"prompt_input": (
|
||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
@ -63,15 +62,6 @@ class TrainingArguments(transformers.TrainingArguments):
|
||||
)
|
||||
|
||||
|
||||
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
||||
"""Collects the state dict and dump to disk."""
|
||||
state_dict = trainer.model.state_dict()
|
||||
if trainer.args.should_save:
|
||||
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
|
||||
del state_dict
|
||||
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
||||
|
||||
|
||||
def smart_tokenizer_and_embedding_resize(
|
||||
special_tokens_dict: Dict,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
@ -205,26 +195,27 @@ def train():
|
||||
padding_side="right",
|
||||
use_fast=False,
|
||||
)
|
||||
special_tokens_dict = dict()
|
||||
if tokenizer.pad_token is None:
|
||||
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
|
||||
if tokenizer.eos_token is None:
|
||||
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
|
||||
if tokenizer.bos_token is None:
|
||||
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
|
||||
if tokenizer.unk_token is None:
|
||||
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
|
||||
|
||||
smart_tokenizer_and_embedding_resize(
|
||||
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
|
||||
special_tokens_dict=special_tokens_dict,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
)
|
||||
if "llama" in model_args.model_name_or_path:
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": DEFAULT_EOS_TOKEN,
|
||||
"bos_token": DEFAULT_BOS_TOKEN,
|
||||
"unk_token": DEFAULT_UNK_TOKEN,
|
||||
}
|
||||
)
|
||||
|
||||
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
|
||||
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
|
||||
trainer.train()
|
||||
trainer.save_state()
|
||||
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
|
||||
trainer.save_model(output_dir=training_args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user