add seq2seq finetuning example

This commit is contained in:
WANG Yue 2023-06-05 19:59:27 +08:00
parent 430d7e358e
commit a459c24048
2 changed files with 157 additions and 0 deletions

View File

@ -58,6 +58,16 @@ outputs = model.generate(**encoding, max_length=15)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
# How to Finetune Using Your Own Data?
We provide an example finetuning script `tune_codet5p_seq2seq.py` for CodeT5+ models on Seq2Seq LM task.
After installing the `transformers` and `datasets` libraries, you can run `python tune_codet5p_seq2seq.py` to finetune CodeT5+ models on any Seq2Seq LM tasks such as Python code summarization.
To finetune on your own data, you just need to prepare your customized cache data in the `datasets` format and pass its path to `--cache-data`.
Besides, you can specify `--load` to specify the CodeT5+ model (such as `Salesforce/codet5p-220m`) to finetune from. To optimize the hyperparameter setting that suit your task best, you can customize other finetuning arguments such as `--epochs`, `--lr`, `--lr-warmup-steps`, `--max-source-len`, `--max-target-len`, `--batch-size-per-replica`, `--grad-acc-steps`, etc.
This script supports multi-GPU training and mixed-precision training by specifying `--fp16`. If you have limited GPU memory issue, consider to use [DeepSpeed](https://github.com/microsoft/DeepSpeed) by passing a deedspeed config file after `--deepspeed` (see [here](https://huggingface.co/docs/transformers/main_classes/deepspeed#zero2-example) for an example config file).
# Reproduce the Results
## HumanEval

View File

@ -0,0 +1,147 @@
"""
Finetune CodeT5+ models on any Seq2Seq LM tasks
You can customize your own training data by following the HF dataset format to cache it to args.cache_data
Author: Yue Wang
Date: June 2023
"""
import os
import pprint
import argparse
from datasets import load_dataset, load_from_disk
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer
def run_training(args, model, train_data):
print(f"Starting main loop")
training_args = TrainingArguments(
report_to='tensorboard',
output_dir=args.save_dir,
overwrite_output_dir=False,
do_train=True,
save_strategy='epoch',
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size_per_replica,
gradient_accumulation_steps=args.grad_acc_steps,
learning_rate=args.lr,
weight_decay=0.05,
warmup_steps=args.lr_warmup_steps,
logging_dir=args.save_dir,
logging_first_step=True,
logging_steps=args.log_freq,
save_total_limit=1,
dataloader_drop_last=True,
dataloader_num_workers=4,
local_rank=args.local_rank,
deepspeed=args.deepspeed,
fp16=args.fp16,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_data,
)
trainer.train()
if args.local_rank in [0, -1]:
final_checkpoint_dir = os.path.join(args.save_dir, "final_checkpoint")
model.save_pretrained(final_checkpoint_dir)
print(f' ==> Finish training and save to {final_checkpoint_dir}')
def load_tokenize_data(args):
# Load and tokenize data
if os.path.exists(args.cache_data):
train_data = load_from_disk(args.cache_data)
print(f' ==> Loaded {len(train_data)} samples')
return train_data
else:
# Example code to load and process code_x_glue_ct_code_to_text python dataset for code summarization task
datasets = load_dataset("code_x_glue_ct_code_to_text", 'python', split="train")
tokenizer = AutoTokenizer.from_pretrained(args.load)
def preprocess_function(examples):
source = [' '.join(ex) for ex in examples["code_tokens"]]
target = [' '.join(ex) for ex in examples["docstring_tokens"]]
model_inputs = tokenizer(source, max_length=args.max_source_len, padding="max_length", truncation=True)
labels = tokenizer(target, max_length=args.max_target_len, padding="max_length", truncation=True)
model_inputs["labels"] = labels["input_ids"].copy()
model_inputs["labels"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in model_inputs["labels"]
]
return model_inputs
train_data = datasets.map(
preprocess_function,
batched=True,
remove_columns=datasets.column_names,
num_proc=64,
load_from_cache_file=False,
)
print(f' ==> Loaded {len(train_data)} samples')
train_data.save_to_disk(args.cache_data)
print(f' ==> Saved to {args.cache_data}')
return train_data
def main(args):
argsdict = vars(args)
print(pprint.pformat(argsdict))
# Save command to file
with open(os.path.join(args.save_dir, "command.txt"), 'w') as f:
f.write(pprint.pformat(argsdict))
# Load and tokenize data using the tokenizer from `args.load`. If the data is already cached, load it from there.
# You can customize this function to load your own data for any Seq2Seq LM tasks.
train_data = load_tokenize_data(args)
if args.data_num != -1:
train_data = train_data.select([i for i in range(args.data_num)])
# Load model from `args.load`
model = AutoModelForSeq2SeqLM.from_pretrained(args.load)
print(f" ==> Loaded model from {args.load}, model size {model.num_parameters()}")
run_training(args, model, train_data)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="CodeT5+ finetuning on Seq2Seq LM task")
parser.add_argument('--data-num', default=-1, type=int)
parser.add_argument('--max-source-len', default=320, type=int)
parser.add_argument('--max-target-len', default=128, type=int)
parser.add_argument('--cache-data', default='cache_data/summarize_python', type=str)
parser.add_argument('--load', default='Salesforce/codet5p-220m', type=str)
# Training
parser.add_argument('--epochs', default=10, type=int)
parser.add_argument('--lr', default=5e-5, type=float)
parser.add_argument('--lr-warmup-steps', default=200, type=int)
parser.add_argument('--batch-size-per-replica', default=8, type=int)
parser.add_argument('--grad-acc-steps', default=4, type=int)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--deepspeed', default=None, type=str)
parser.add_argument('--fp16', default=False, action='store_true')
# Logging and stuff
parser.add_argument('--save-dir', default="saved_models/summarize_python", type=str)
parser.add_argument('--log-freq', default=10, type=int)
parser.add_argument('--save-freq', default=500, type=int)
args = parser.parse_args()
os.makedirs(args.save_dir, exist_ok=True)
main(args)