diff --git a/README.md b/README.md
index ced9de5..3b208c2 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,7 @@ The inner circle of the plot represents the root verb of the instructions, and t
![](./assets/parse_analysis.png)
## Fine-tuning
-We fine-tune our model using standard Hugging Face training code with the following hyperparameters:
+We fine-tune our models using standard Hugging Face training code with the following hyperparameters:
| Hyperparameter | Value |
|----------------|-------|
@@ -101,7 +101,68 @@ We fine-tune our model using standard Hugging Face training code with the follow
| Max length | 512 |
| Weight decay | 1 |
-We are waiting for Hugging Face to officially support the llama models (i.e. this [PR](https://github.com/huggingface/transformers/pull/21955) to be merged) before we release a stable version of the finetuning code.
+Since Hugging Face has not yet officially released support for the LLaMA models, we fine-tuned LLaMA with Hugging Face's transformers library installed from a particular fork (i.e., the branch behind this [PR](https://github.com/huggingface/transformers/pull/21955), which has yet to be merged).
+The hash of the specific commit we installed was `68d640f7c368bcaaaecfc678f11908ebbd3d6176`.
+
+To reproduce our fine-tuning runs for LLaMA, first install the requirements
+```bash
+pip install -r requirements.txt
+```
+Then, install the particular fork of Hugging Face's transformers library.
+
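+One way to pin the fork at the commit above is to fetch the pull request's head directly from GitHub. This is only a sketch: it assumes the commit is still reachable from the PR branch, and `llama_pr` is just a local branch name.
+
+```bash
+git clone https://github.com/huggingface/transformers.git
+cd transformers
+# GitHub exposes pull-request branches under refs/pull/<id>/head
+git fetch origin pull/21955/head:llama_pr
+git checkout 68d640f7c368bcaaaecfc678f11908ebbd3d6176
+pip install .
+```
+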
+Below is a command that fine-tunes LLaMA-7B with our dataset on a machine with 4 A100 80G GPUs in FSDP `full_shard` mode.
+Replace `<your_random_port>` with a port of your own, `<your_path_to_hf_converted_llama_ckpt_and_tokenizer>` with the
+path to your converted checkpoint and tokenizer (following instructions in the PR), and `<your_output_dir>` with where you want to store your outputs.
+
+```bash
+torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
+    --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
+    --data_path ./alpaca_data.json \
+    --bf16 True \
+    --output_dir <your_output_dir> \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 4 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 8 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 2000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --fsdp "full_shard auto_wrap" \
+ --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
+ --tf32 True
+```
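+
+For reference, the effective batch size of the command above is 4 GPUs × 4 (per-device batch size) × 8 (gradient accumulation steps) = 128 sequences per optimizer step.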
+
+The same script also works for OPT fine-tuning. Here's an example for fine-tuning OPT-6.7B.
+
+```bash
+torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
+    --model_name_or_path "facebook/opt-6.7b" \
+    --data_path ./alpaca_data.json \
+    --bf16 True \
+    --output_dir <your_output_dir> \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 4 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 8 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 2000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --fsdp "full_shard auto_wrap" \
+ --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \
+ --tf32 True
+```
### Authors
All grad students below contributed equally and the order is determined by random draw.
diff --git a/requirements.txt b/requirements.txt
index 5276c49..61c87ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,9 @@
numpy
rouge_score
fire
-openai
\ No newline at end of file
+openai
+transformers>=4.26.1
+torch
+sentencepiece
+tokenizers==0.12.1
+wandb
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..0a17a26
--- /dev/null
+++ b/train.py
@@ -0,0 +1,232 @@
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+from dataclasses import dataclass, field
+from typing import Optional, Dict, Sequence
+
+import torch
+import transformers
+from torch.utils.data import Dataset
+from transformers import Trainer
+
+import utils
+
+IGNORE_INDEX = -100
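+# Default special tokens: the pad token is added below only if the tokenizer lacks one; EOS/BOS/UNK are set explicitly for LLaMA.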
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "<s>"
+DEFAULT_UNK_TOKEN = "<unk>"
+PROMPT_DICT = {
+ "prompt_input": (
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+ ),
+ "prompt_no_input": (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Response:"
+ ),
+}
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=512,
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+ )
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """Collects the state dict and dumps it to disk."""
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
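+        # Copy tensors to CPU before saving so the checkpoint is written from host memory on the saving rank.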
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
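+    # Initialize the newly added token embeddings to the mean of the pre-existing embeddings.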
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ )
+ for text in strings
+ ]
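+    # Each string is tokenized on its own, so padding="longest" adds no padding here; the lengths below count non-pad tokens.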
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def preprocess(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """Preprocess the data by tokenizing."""
+ examples = [s + t for s, t in zip(sources, targets)]
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
+ input_ids = examples_tokenized["input_ids"]
+ labels = copy.deepcopy(input_ids)
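+    # Mask out the prompt (source) tokens with IGNORE_INDEX so the loss is computed only on the response tokens.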
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+ label[:source_len] = IGNORE_INDEX
+ return dict(input_ids=input_ids, labels=labels)
+
+
+class SupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+ super(SupervisedDataset, self).__init__()
+ logging.warning("Loading data...")
+ list_data_dict = utils.jload(data_path)
+
+ logging.warning("Formatting inputs...")
+ prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+ sources = [
+ prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
+ for example in list_data_dict
+ ]
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+ logging.warning("Tokenizing inputs... This may take some time...")
+ data_dict = preprocess(sources, targets, tokenizer)
+
+ self.input_ids = data_dict["input_ids"]
+ self.labels = data_dict["labels"]
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
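+        # Right-pad to the longest sequence in the batch; padded label positions use IGNORE_INDEX and are ignored by the loss.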
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
+
+
+def train():
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ )
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
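+    # LLaMA's tokenizer ships without a pad token; add one (and resize the embeddings) so batches can be padded.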
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+ tokenizer=tokenizer,
+ model=model,
+ )
+ if "llama" in model_args.model_name_or_path:
+ tokenizer.add_special_tokens(
+ {
+ "eos_token": DEFAULT_EOS_TOKEN,
+ "bos_token": DEFAULT_BOS_TOKEN,
+ "unk_token": DEFAULT_UNK_TOKEN,
+ }
+ )
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
+    trainer.train()
+    trainer.save_state()
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()