# coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned using a masked language modeling (MLM) loss. """ import os import logging import argparse import math import numpy as np from tqdm import tqdm import multiprocessing import time import torch from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader, SequentialSampler, RandomSampler from torch.utils.data.distributed import DistributedSampler from transformers import AdamW, get_linear_schedule_with_warmup from models import build_or_load_gen_model from evaluator import smooth_bleu from evaluator.CodeBLEU import calc_code_bleu from evaluator.bleu import _bleu from utils import get_filenames, get_elapse_time, load_and_cache_gen_data from configs import add_args, set_seed, set_dist logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) logger = logging.getLogger(__name__) def eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer): eval_sampler = SequentialSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=4, pin_memory=True) # Start evaluating model logger.info(" " + "***** Running ppl evaluation *****") logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Batch size = %d", args.eval_batch_size) model.eval() eval_loss, batch_num = 0, 0 for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval ppl"): batch = tuple(t.to(args.device) for t in batch) source_ids, target_ids = batch source_mask = source_ids.ne(tokenizer.pad_token_id) target_mask = target_ids.ne(tokenizer.pad_token_id) with torch.no_grad(): if args.model_type == 'roberta': loss, _, _ = model(source_ids=source_ids, source_mask=source_mask, target_ids=target_ids, target_mask=target_mask) else: outputs = model(input_ids=source_ids, attention_mask=source_mask, labels=target_ids, decoder_attention_mask=target_mask) loss = outputs.loss eval_loss += loss.item() batch_num += 1 eval_loss = eval_loss / batch_num eval_ppl = round(np.exp(eval_loss), 5) return eval_ppl def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag, criteria): logger.info(" ***** Running bleu evaluation on {} data*****".format(split_tag)) logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Batch size = %d", args.eval_batch_size) eval_sampler = SequentialSampler(eval_data) if args.data_num == -1: eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=4, pin_memory=True) else: eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) model.eval() pred_ids = [] bleu, codebleu = 0.0, 0.0 for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval bleu for {} set".format(split_tag)): source_ids = batch[0].to(args.device) source_mask = source_ids.ne(tokenizer.pad_token_id) with torch.no_grad(): if args.model_type == 'roberta': preds = model(source_ids=source_ids, source_mask=source_mask) top_preds = [pred[0].cpu().numpy() for pred in preds] else: preds = model.generate(source_ids, attention_mask=source_mask, use_cache=True, num_beams=args.beam_size, early_stopping=args.task == 'summarize', max_length=args.max_target_length) top_preds = list(preds.cpu().numpy()) pred_ids.extend(top_preds) pred_nls = [tokenizer.decode(id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in pred_ids] output_fn = os.path.join(args.res_dir, "test_{}.output".format(criteria)) gold_fn = os.path.join(args.res_dir, "test_{}.gold".format(criteria)) src_fn = os.path.join(args.res_dir, "test_{}.src".format(criteria)) if args.task in ['defect']: target_dict = {0: 'false', 1: 'true'} golds = [target_dict[ex.target] for ex in eval_examples] eval_acc = np.mean([int(p == g) for p, g in zip(pred_nls, golds)]) result = {'em': eval_acc * 100, 'bleu': 0, 'codebleu': 0} with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2: for pred_nl, gold in zip(pred_nls, eval_examples): f.write(pred_nl.strip() + '\n') f1.write(target_dict[gold.target] + '\n') f2.write(gold.source.strip() + '\n') logger.info("Save the predictions into %s", output_fn) else: dev_accs, predictions = [], [] with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2: for pred_nl, gold in zip(pred_nls, eval_examples): dev_accs.append(pred_nl.strip() == gold.target.strip()) if args.task in ['summarize']: # for smooth-bleu4 evaluation predictions.append(str(gold.idx) + '\t' + pred_nl) f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n') f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n') f2.write(str(gold.idx) + '\t' + gold.source.strip() + '\n') else: f.write(pred_nl.strip() + '\n') f1.write(gold.target.strip() + '\n') f2.write(gold.source.strip() + '\n') if args.task == 'summarize': (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn) bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2) else: bleu = round(_bleu(gold_fn, output_fn), 2) if args.task in ['concode', 'translate', 'refine']: codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, args.lang) result = {'em': np.mean(dev_accs) * 100, 'bleu': bleu} if args.task == 'concode': result['codebleu'] = codebleu * 100 logger.info("***** Eval results *****") for key in sorted(result.keys()): logger.info(" %s = %s", key, str(round(result[key], 4))) return result def main(): parser = argparse.ArgumentParser() args = add_args(parser) logger.info(args) t0 = time.time() set_dist(args) set_seed(args) config, model, tokenizer = build_or_load_gen_model(args) model.to(args.device) if args.n_gpu > 1: # for DataParallel model = torch.nn.DataParallel(model) pool = multiprocessing.Pool(args.cpu_cont) args.train_filename, args.dev_filename, args.test_filename = get_filenames(args.data_dir, args.task, args.sub_task) fa = open(os.path.join(args.output_dir, 'summary.log'), 'a+') if args.do_train: if args.local_rank in [-1, 0] and args.data_num == -1: summary_fn = '{}/{}'.format(args.summary_dir, '/'.join(args.output_dir.split('/')[1:])) tb_writer = SummaryWriter(summary_fn) # Prepare training data loader train_examples, train_data = load_and_cache_gen_data(args, args.train_filename, pool, tokenizer, 'train') train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=4, pin_memory=True) # Prepare optimizer and schedule (linear warmup and decay) no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) num_train_optimization_steps = args.num_train_epochs * len(train_dataloader) scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=num_train_optimization_steps) # Start training train_example_num = len(train_data) logger.info("***** Running training *****") logger.info(" Num examples = %d", train_example_num) logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Batch num = %d", math.ceil(train_example_num / args.train_batch_size)) logger.info(" Num epoch = %d", args.num_train_epochs) dev_dataset = {} global_step, best_bleu_em, best_ppl = 0, -1, 1e6 not_loss_dec_cnt, not_bleu_em_inc_cnt = 0, 0 if args.do_eval_bleu else 1e6 for cur_epoch in range(args.start_epoch, int(args.num_train_epochs)): bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training") nb_tr_examples, nb_tr_steps, tr_loss = 0, 0, 0 model.train() for step, batch in enumerate(bar): batch = tuple(t.to(args.device) for t in batch) source_ids, target_ids = batch source_mask = source_ids.ne(tokenizer.pad_token_id) target_mask = target_ids.ne(tokenizer.pad_token_id) if args.model_type == 'roberta': loss, _, _ = model(source_ids=source_ids, source_mask=source_mask, target_ids=target_ids, target_mask=target_mask) else: outputs = model(input_ids=source_ids, attention_mask=source_mask, labels=target_ids, decoder_attention_mask=target_mask) loss = outputs.loss if args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps tr_loss += loss.item() nb_tr_examples += source_ids.size(0) nb_tr_steps += 1 loss.backward() if nb_tr_steps % args.gradient_accumulation_steps == 0: # Update parameters optimizer.step() optimizer.zero_grad() scheduler.step() global_step += 1 train_loss = round(tr_loss * args.gradient_accumulation_steps / (nb_tr_steps + 1), 4) bar.set_description("[{}] Train loss {}".format(cur_epoch, round(train_loss, 3))) if args.do_eval: # Eval model with dev dataset if 'dev_loss' in dev_dataset: eval_examples, eval_data = dev_dataset['dev_loss'] else: eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev') dev_dataset['dev_loss'] = eval_examples, eval_data eval_ppl = eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer) result = {'epoch': cur_epoch, 'global_step': global_step, 'eval_ppl': eval_ppl} for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) logger.info(" " + "*" * 20) if args.data_num == -1: tb_writer.add_scalar('dev_ppl', eval_ppl, cur_epoch) # save last checkpoint if args.save_last_checkpoints: last_output_dir = os.path.join(args.output_dir, 'checkpoint-last') if not os.path.exists(last_output_dir): os.makedirs(last_output_dir) model_to_save = model.module if hasattr(model, 'module') else model output_model_file = os.path.join(last_output_dir, "pytorch_model.bin") torch.save(model_to_save.state_dict(), output_model_file) logger.info("Save the last model into %s", output_model_file) if eval_ppl < best_ppl: not_loss_dec_cnt = 0 logger.info(" Best ppl:%s", eval_ppl) logger.info(" " + "*" * 20) fa.write("[%d] Best ppl changed into %.4f\n" % (cur_epoch, eval_ppl)) best_ppl = eval_ppl # Save best checkpoint for best ppl output_dir = os.path.join(args.output_dir, 'checkpoint-best-ppl') if not os.path.exists(output_dir): os.makedirs(output_dir) if args.always_save_model: model_to_save = model.module if hasattr(model, 'module') else model output_model_file = os.path.join(output_dir, "pytorch_model.bin") torch.save(model_to_save.state_dict(), output_model_file) logger.info("Save the best ppl model into %s", output_model_file) else: not_loss_dec_cnt += 1 logger.info("Ppl does not decrease for %d epochs", not_loss_dec_cnt) if all([x > args.patience for x in [not_bleu_em_inc_cnt, not_loss_dec_cnt]]): early_stop_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % ( cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt) logger.info(early_stop_str) fa.write(early_stop_str) break logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_eval_bleu: eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev', only_src=True, is_sample=True) result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'dev', 'e%d' % cur_epoch) dev_bleu, dev_em = result['bleu'], result['em'] if args.task in ['summarize']: dev_bleu_em = dev_bleu elif args.task in ['defect']: dev_bleu_em = dev_em else: dev_bleu_em = dev_bleu + dev_em if args.data_num == -1: tb_writer.add_scalar('dev_bleu_em', dev_bleu_em, cur_epoch) # tb_writer.add_scalar('dev_em', dev_em, cur_epoch) if dev_bleu_em > best_bleu_em: not_bleu_em_inc_cnt = 0 logger.info(" [%d] Best bleu+em: %.2f (bleu: %.2f, em: %.2f)", cur_epoch, dev_bleu_em, dev_bleu, dev_em) logger.info(" " + "*" * 20) best_bleu_em = dev_bleu_em fa.write("[%d] Best bleu+em changed into %.2f (bleu: %.2f, em: %.2f)\n" % ( cur_epoch, best_bleu_em, dev_bleu, dev_em)) # Save best checkpoint for best bleu output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu') if not os.path.exists(output_dir): os.makedirs(output_dir) if args.data_num == -1 or args.always_save_model: model_to_save = model.module if hasattr(model, 'module') else model output_model_file = os.path.join(output_dir, "pytorch_model.bin") torch.save(model_to_save.state_dict(), output_model_file) logger.info("Save the best bleu model into %s", output_model_file) else: not_bleu_em_inc_cnt += 1 logger.info("Bleu does not increase for %d epochs", not_bleu_em_inc_cnt) fa.write( "[%d] Best bleu+em (%.2f) does not drop changed for %d epochs, cur bleu+em: %.2f (bleu: %.2f, em: %.2f)\n" % ( cur_epoch, best_bleu_em, not_bleu_em_inc_cnt, dev_bleu_em, dev_bleu, dev_em)) if all([x > args.patience for x in [not_bleu_em_inc_cnt, not_loss_dec_cnt]]): stop_early_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % ( cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt) logger.info(stop_early_str) fa.write(stop_early_str) break logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.local_rank in [-1, 0] and args.data_num == -1: tb_writer.close() logger.info("Finish training and take %s", get_elapse_time(t0)) if args.do_test: logger.info(" " + "***** Testing *****") logger.info(" Batch size = %d", args.eval_batch_size) for criteria in ['best-bleu']: file = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria)) logger.info("Reload model from {}".format(file)) model.load_state_dict(torch.load(file)) eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test', only_src=True, is_sample=False) result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test', criteria) test_bleu, test_em = result['bleu'], result['em'] test_codebleu = result['codebleu'] if 'codebleu' in result else 0 result_str = "[%s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (criteria, test_bleu, test_em, test_codebleu) logger.info(result_str) fa.write(result_str) if args.res_fn: with open(args.res_fn, 'a+') as f: f.write('[Time: {}] {}\n'.format(get_elapse_time(t0), file)) f.write(result_str) logger.info("Finish and take {}".format(get_elapse_time(t0))) fa.write("Finish and take {}".format(get_elapse_time(t0))) fa.close() if __name__ == "__main__": main()