mirror of
https://github.com/salesforce/CodeT5.git
synced 2024-10-01 06:35:38 -04:00
240 lines
11 KiB
Python
240 lines
11 KiB
Python
import pdb
|
||
from torch.nn.init import xavier_uniform_
|
||
from torch.utils.data import TensorDataset
|
||
import numpy as np
|
||
import logging
|
||
import os
|
||
import random
|
||
import torch
|
||
import time
|
||
from tqdm import tqdm
|
||
from _utils import *
|
||
|
||
# Module-level logger named after this module; used by the helpers below for progress/statistics messages.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def load_and_cache_gen_data(args, filename, pool, tokenizer, split_tag, only_src=False, is_sample=False):
    """Read generation examples and return them with their tensorised dataset.

    The dataset is cached under ``args.cache_path`` unless ``is_sample`` is set,
    in which case at most 5000 randomly chosen examples are converted and
    nothing is read from or written to the cache.

    Args:
        args: run configuration; ``data_num``, ``cache_path``, ``task`` and
            ``local_rank`` are read here.
        filename: data file(s) handed to ``read_examples``.
        pool: object whose ``map`` runs ``convert_examples_to_features``
            over the examples.
        tokenizer: forwarded to the feature converter and, for the train
            split, to ``calc_stats``.
        split_tag: split name; ``'train'`` enables tokenised statistics and
            ``'test'`` forces a source-only dataset.
        only_src: when true, keep only the source ids (used for BLEU
            evaluation on dev/test).
        is_sample: when true, work on a random sample and bypass the cache.

    Returns:
        ``(examples, data)`` where ``data`` is a ``TensorDataset`` of source
        ids, plus target ids unless the split is ``'test'`` or ``only_src``.
    """
    size_suffix = '_all' if args.data_num == -1 else '_%d' % args.data_num
    src_suffix = '_src' if only_src else ''
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + src_suffix + size_suffix)

    examples = read_examples(filename, args.data_num, args.task)
    if is_sample:
        examples = random.sample(examples, min(5000, len(examples)))

    if split_tag == 'train':
        calc_stats(examples, tokenizer, is_tokenize=True)
    else:
        calc_stats(examples)

    # Cache hit: sampled runs never use the cache.
    if not is_sample and os.path.exists(cache_fn):
        logger.info("Load cache data from %s", cache_fn)
        return examples, torch.load(cache_fn)

    if is_sample:
        logger.info("Sample 5k data for computing bleu from %s", filename)
    else:
        logger.info("Create cache data into %s", cache_fn)

    work_items = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(examples)]
    features = pool.map(convert_examples_to_features, tqdm(work_items, total=len(work_items)))

    tensors = [torch.tensor([f.source_ids for f in features], dtype=torch.long)]
    if split_tag != 'test' and not only_src:
        tensors.append(torch.tensor([f.target_ids for f in features], dtype=torch.long))
    data = TensorDataset(*tensors)

    # Only the main process (local_rank -1 or 0) writes the cache file.
    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(data, cache_fn)
    return examples, data
|
||
|
||
|
||
def load_and_cache_multi_gen_data(args, split_tag, pool, tokenizer, encode_target=True, is_sample=False):
    """Build, or load from cache, the datasets of every task/sub-task for one split.

    For each task and sub-task this function overwrites ``args.task``,
    ``args.sub_task``, ``args.max_source_length`` and ``args.max_target_length``
    before converting the examples, so ``args`` is left holding the values of
    the last task processed when the cache is not used.

    Args:
        args: run configuration; ``cache_path``, ``data_dir``, ``data_num`` and
            ``local_rank`` are read, and the task/length fields are written.
        split_tag: split name, also used as the cache file name.
        pool: object whose ``map`` is used for conversion when
            ``args.data_num == -1``; otherwise conversion runs in-process.
        tokenizer: forwarded to the feature converter and to ``calc_stats``.
        encode_target: when true the dataset also holds the target ids.
        is_sample: when true, use at most 5000 random examples per sub-task
            and bypass the cache.

    Returns:
        dict mapping ``'<task>_<sub_task>'`` (or just ``'<task>'`` when the
        task has no sub-tasks) to an ``(examples, TensorDataset)`` pair.
    """
    cache_fn = os.path.join(args.cache_path, split_tag)
    if not is_sample and os.path.exists(cache_fn):
        logger.info("Load cache data from %s", cache_fn)
        return torch.load(cache_fn)

    # Tasks in processing order; 'none' marks a task without sub-tasks.
    sub_tasks_by_task = {
        'summarize': ['ruby', 'javascript', 'go', 'python', 'java', 'php'],
        'translate': ['java-cs', 'cs-java'],
        'refine': ['small', 'medium'],
        'concode': ['none'],
        'defect': ['none'],
    }
    # (max_source_length, max_target_length) for tasks whose limits ignore the sub-task.
    fixed_lengths = {
        'summarize': (256, 128),
        'translate': (320, 256),
        'concode': (320, 150),
        'defect': (512, 3),  # 3 suffices, as lang ids do not need to be added
    }

    examples_data_dict = {}
    for task, sub_tasks in sub_tasks_by_task.items():
        args.task = task
        for sub_task in sub_tasks:
            args.sub_task = sub_task
            if task == 'refine':
                src_len, trg_len = (130, 120) if sub_task == 'small' else (240, 240)
            else:
                src_len, trg_len = fixed_lengths[task]
            args.max_source_length = src_len
            args.max_target_length = trg_len

            filename = get_filenames(args.data_dir, args.task, args.sub_task, split_tag)
            examples = read_examples(filename, args.data_num, args.task)
            if is_sample:
                examples = random.sample(examples, min(5000, len(examples)))

            if split_tag == 'train':
                calc_stats(examples, tokenizer, is_tokenize=True)
            else:
                calc_stats(examples)

            work_items = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(examples)]
            if args.data_num == -1:
                features = pool.map(convert_examples_to_features, tqdm(work_items, total=len(work_items)))
            else:
                # A restricted data_num is converted in-process, without the pool.
                features = [convert_examples_to_features(item) for item in work_items]

            tensors = [torch.tensor([f.source_ids for f in features], dtype=torch.long)]
            if encode_target:
                tensors.append(torch.tensor([f.target_ids for f in features], dtype=torch.long))
            data = TensorDataset(*tensors)

            key = task if sub_task == 'none' else '{}_{}'.format(task, sub_task)
            examples_data_dict[key] = (examples, data)

    # Only the main process (local_rank -1 or 0) writes the cache, and never for samples.
    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(examples_data_dict, cache_fn)
        logger.info("Save data into %s", cache_fn)
    return examples_data_dict
|
||
|
||
|
||
def load_and_cache_clone_data(args, filename, pool, tokenizer, split_tag, is_sample=False):
    """Read clone-detection examples and return them with their tensorised dataset.

    The dataset is cached under ``args.cache_path`` only for full,
    non-sampled runs (``args.data_num == -1`` and ``is_sample`` false).
    Sampled runs neither read nor write the cache, so the returned
    ``examples`` and ``data`` always describe the same items, matching the
    policy of ``load_and_cache_gen_data``.

    Args:
        args: run configuration; ``data_num``, ``cache_path``, ``task`` and
            ``local_rank`` are read here.
        filename: data file handed to ``read_examples``.
        pool: object whose ``map`` runs ``convert_clone_examples_to_features``
            over the examples.
        tokenizer: forwarded to the feature converter and to ``calc_stats``.
        split_tag: split name, used as the cache file prefix.
        is_sample: when true, keep a random 10 percent of the examples.

    Returns:
        ``(examples, data)`` where ``data`` is a ``TensorDataset`` of
        ``(source_ids, labels)``.
    """
    # Build the size tag separately: the conditional expression binds looser
    # than '+', so inlining it would drop split_tag whenever data_num != -1.
    data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + data_tag)

    examples = read_examples(filename, args.data_num, args.task)
    if is_sample:
        examples = random.sample(examples, int(len(examples) * 0.1))

    calc_stats(examples, tokenizer, is_tokenize=True)
    if os.path.exists(cache_fn) and not is_sample:
        logger.info("Load cache data from %s", cache_fn)
        # NOTE(review): torch.load unpickles the file; only load caches this code produced.
        data = torch.load(cache_fn)
    else:
        if is_sample:
            logger.info("Sample 10 percent of data from %s", filename)
        elif args.data_num == -1:
            logger.info("Create cache data into %s", cache_fn)
        tuple_examples = [(example, idx, tokenizer, args) for idx, example in enumerate(examples)]
        features = pool.map(convert_clone_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples)))
        all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        data = TensorDataset(all_source_ids, all_labels)

        # Only the main process writes the cache, only for the full dataset, and
        # never for a random sample (it would masquerade as the full '_all' cache).
        if args.local_rank in [-1, 0] and args.data_num == -1 and not is_sample:
            torch.save(data, cache_fn)
    return examples, data
|
||
|
||
|
||
def get_filenames(data_root, task, sub_task, split=''):
    """Return the data file path(s) of a task for one split or for all splits.

    Paired datasets (``refine``, ``translate``) are returned as a single
    ``'source_path,target_path'`` string per split.

    Args:
        data_root: root directory that contains one folder per task.
        task: one of ``'concode'``, ``'summarize'``, ``'refine'``,
            ``'translate'``, ``'clone'`` or ``'defect'``.
        sub_task: sub-folder for ``summarize``/``refine``; for ``translate``
            the value ``'cs-java'`` selects C# as the source side, any other
            value selects Java as the source side. Ignored by other tasks.
        split: ``'train'``, ``'dev'`` or ``'test'`` to get a single path;
            anything else returns all three.

    Returns:
        A path string for a named split, otherwise the tuple
        ``(train_fn, dev_fn, test_fn)``.

    Raises:
        ValueError: if ``task`` is not one of the supported tasks.
    """
    # One pattern per task: {d} is the data directory, {s} the split stem.
    if task == 'concode':
        pattern = '{d}/{s}.json'
    elif task in ('summarize', 'defect'):
        pattern = '{d}/{s}.jsonl'
    elif task == 'clone':
        pattern = '{d}/{s}.txt'
    elif task == 'refine':
        pattern = '{d}/{s}.buggy-fixed.buggy,{d}/{s}.buggy-fixed.fixed'
    elif task == 'translate':
        if sub_task == 'cs-java':
            pattern = '{d}/{s}.java-cs.txt.cs,{d}/{s}.java-cs.txt.java'
        else:
            pattern = '{d}/{s}.java-cs.txt.java,{d}/{s}.java-cs.txt.cs'
    else:
        raise ValueError('Unsupported task for get_filenames: {!r}'.format(task))

    # Only summarize and refine keep their files in a per-sub-task folder.
    if task in ('summarize', 'refine'):
        data_dir = '{}/{}/{}'.format(data_root, task, sub_task)
    else:
        data_dir = '{}/{}'.format(data_root, task)

    # The dev split file is named "dev" for concode and "valid" for every other task.
    dev_stem = 'dev' if task == 'concode' else 'valid'
    train_fn = pattern.format(d=data_dir, s='train')
    dev_fn = pattern.format(d=data_dir, s=dev_stem)
    test_fn = pattern.format(d=data_dir, s='test')

    if split == 'train':
        return train_fn
    if split == 'dev':
        return dev_fn
    if split == 'test':
        return test_fn
    return train_fn, dev_fn, test_fn
|
||
|
||
|
||
def read_examples(filename, data_num, task):
    """Read the examples of ``task`` from ``filename`` using that task's reader.

    Args:
        filename: data file(s) passed through to the reader.
        data_num: passed through to the reader unchanged.
        task: key selecting the reader; an unknown key raises ``KeyError``.

    Returns:
        Whatever the selected reader returns.
    """
    readers = {
        # The original note lists read_summarize_indent_examples as an alternative reader here.
        'summarize': read_summarize_examples,
        'refine': read_refine_examples,
        'translate': read_translate_examples,
        'concode': read_concode_examples,
        'clone': read_clone_examples,
        'defect': read_defect_examples,
    }
    reader = readers[task]
    return reader(filename, data_num)
|
||
|
||
|
||
def calc_stats(examples, tokenizer=None, is_tokenize=False):
    """Log length statistics of the source and target texts of ``examples``.

    Whitespace-split lengths are always logged. When ``is_tokenize`` is true
    the lengths after ``tokenizer.tokenize`` are logged as well. Nothing is
    returned.

    Args:
        examples: sized collection of objects with ``source`` (string) and
            ``target`` (converted with ``str``) attributes.
        tokenizer: object with a ``tokenize`` method; required when
            ``is_tokenize`` is true.
        is_tokenize: whether to also log token-level statistics.
    """
    if not examples:
        # max() of an empty sequence raises, so report and stop here.
        logger.info("Read 0 examples, skip length statistics")
        return

    src_lens = [len(ex.source.split()) for ex in examples]
    trg_lens = [len(str(ex.target).split()) for ex in examples]
    logger.info("Read %d examples, avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d",
                len(examples), np.mean(src_lens), np.mean(trg_lens), max(src_lens), max(trg_lens))

    if is_tokenize:
        src_tok_lens = [len(tokenizer.tokenize(ex.source)) for ex in examples]
        trg_tok_lens = [len(tokenizer.tokenize(str(ex.target))) for ex in examples]
        logger.info("[TOKENIZE] avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d",
                    np.mean(src_tok_lens), np.mean(trg_tok_lens), max(src_tok_lens), max(trg_tok_lens))
|
||
|
||
|
||
def get_elapse_time(t0):
    """Format the wall-clock time elapsed since ``t0`` as ``'<h>h<m>m'`` or ``'<m>m'``.

    Args:
        t0: start timestamp in seconds, as returned by ``time.time()``.

    Returns:
        ``'{hours}h{minutes}m'`` once a full hour has elapsed, otherwise
        ``'{minutes}m'``. Seconds are truncated. A ``t0`` in the future is
        treated as zero elapsed time.
    """
    # Clamp so a start time ahead of the clock does not wrap to a bogus "59m".
    elapsed_seconds = int(max(time.time() - t0, 0.0))
    hours, remainder = divmod(elapsed_seconds, 3600)
    minutes = remainder // 60
    # hours >= 1 covers exactly 3600 seconds too, which must read "1h0m", not "0m".
    if hours:
        return "{}h{}m".format(hours, minutes)
    return "{}m".format(minutes)
|