mirror of https://github.com/salesforce/CodeT5.git
synced 2024-10-01 06:35:38 -04:00
add embedding model and code retrieval evaluation
This commit is contained in:
parent ebf3075b24
commit 2bcbf6b473
@@ -24,6 +24,8 @@ Furthermore, we explore instruction tuning to align the model with natural language instructions.
3. [Instruction Tuning to Align with Natural Language Instructions](#instruction-tuning-to-align-with-natural-language-instructions)
4. [How to Finetune Using Your Own Data?](#how-to-finetune-using-your-own-data)
5. [Reproduce the Results](#reproduce-the-results)
   1. [HumanEval](#humaneval)
   2. [Text-to-Code Retrieval](#text-to-code-retrieval)
6. [Citation](#citation)
@@ -34,6 +36,7 @@ InstructCodeT5+ 16B is our instruction-tuned model from CodeT5+ 16B.
Note that as this model utilizes instruction tuning data curated using the OpenAI API, the checkpoint of InstructCodeT5+ 16B is licensed for research and **non-commercial** use only.

We release the following CodeT5+ models on Hugging Face:
* CodeT5+ `110M` embedding model: [codet5p-110m-embedding](https://huggingface.co/Salesforce/codet5p-110m-embedding).
* CodeT5+ `220M` and `770M`: [codet5p-220m](https://huggingface.co/Salesforce/codet5p-220m) and [codet5p-770m](https://huggingface.co/Salesforce/codet5p-770m).
* CodeT5+ `220M` and `770M` further tuned on the Python subset: [codet5p-220m-py](https://huggingface.co/Salesforce/codet5p-220m-py) and [codet5p-770m-py](https://huggingface.co/Salesforce/codet5p-770m-py).
* CodeT5+ `2B`, `6B`, `16B`: [codet5p-2b](https://huggingface.co/Salesforce/codet5p-2b), [codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b), and [codet5p-16b](https://huggingface.co/Salesforce/codet5p-16b).
@@ -68,6 +71,24 @@ outputs = model.generate(**encoding, max_length=15)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### CodeT5+ embedding model 🔥
Apart from the generative models, we also release the [CodeT5+ 110M embedding](https://huggingface.co/Salesforce/codet5p-110m-embedding) model, which can be used to extract code embeddings. This checkpoint contains the encoder of the CodeT5+ 220M model, pretrained in two stages on both unimodal and bimodal data, as well as a linear projection layer that maps the encoder output to a 256-dimensional vector.

```python
from transformers import AutoModel, AutoTokenizer

checkpoint = "Salesforce/codet5p-110m-embedding"
device = "cuda"  # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to(device)

inputs = tokenizer.encode("def print_hello_world():\tprint('Hello World!')", return_tensors="pt").to(device)
embedding = model(inputs)[0]
print(f'Dimension of the embedding: {embedding.size()[0]}, with norm={embedding.norm().item()}')
# Dimension of the embedding: 256, with norm=1.0
```
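
Since the printed norm is 1.0, the embeddings come out L2-normalized, so a plain dot product equals cosine similarity. Below is a minimal, illustrative sketch of text-to-code retrieval built on the snippet above; the candidate snippets, the query string, and the `embed` helper are made up for this example and are not part of the released code. The same dot-product scoring is what `eval_contrast_retrieval.py` (added in this commit) uses over the full benchmark codebases.

```python
import torch

# Hypothetical candidate snippets to search over (illustrative only).
code_candidates = [
    "def add(a, b):\n    return a + b",
    "def print_hello_world():\n    print('Hello World!')",
]

@torch.no_grad()
def embed(texts):
    # Encode each string into a normalized 256-d embedding using the model loaded above.
    vecs = [model(tokenizer.encode(t, return_tensors="pt").to(device))[0] for t in texts]
    return torch.stack(vecs)

query_emb = embed(["print hello world"])   # shape (1, 256)
code_embs = embed(code_candidates)         # shape (N, 256)
scores = query_emb @ code_embs.t()         # dot product == cosine similarity for unit vectors
best = scores.argmax(dim=1).item()
print(code_candidates[best])
```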

# Instruction Tuning to Align with Natural Language Instructions

We explore instruction tuning to align CodeT5+ with natural language instructions following [Code Alpaca](https://github.com/sahil280114/codealpaca). First download the instruction data `code_alpaca_20k.json` from [here](https://github.com/sahil280114/codealpaca/tree/master/data).
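
As a quick sanity check after downloading, the sketch below loads the file and assembles one training prompt. It assumes the standard Alpaca-style record fields (`instruction`, `input`, `output`) and the usual Alpaca prompt template; the exact prompt format used for tuning in this repo may differ.

```python
import json

# Illustrative only: inspect one record of code_alpaca_20k.json (assumed Alpaca-style fields).
with open("code_alpaca_20k.json", encoding="utf-8") as f:
    data = json.load(f)

ex = data[0]
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    f"### Instruction:\n{ex['instruction']}\n\n"
)
if ex.get("input"):
    prompt += f"### Input:\n{ex['input']}\n\n"
prompt += "### Response:"
print(prompt)
print(ex["output"])  # target completion
```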
@@ -182,6 +203,41 @@ It can reproduce the results of `36.1% Pass@1` with the following command.
evaluate_functional_correctness humaneval/instructcodet5p-16b_T0.2_N200.jsonl
```

## Text-to-Code Retrieval
* Download and preprocess the three text-to-code retrieval datasets following the instructions in this [repo](https://github.com/microsoft/CodeBERT/tree/master/UniXcoder/downstream-tasks/code-search#data-download).
* `cd code_retrieval`, then run the evaluation of our CodeT5+ 110M embedding model via `bash run_retrieval.sh`.

```bash
# LANG choices: ruby javascript go python java php AdvTest cosqa
LANG=ruby
BS=256
CODE_LEN=360
TEXT_LEN=64
MODEL_NAME=Salesforce/codet5p-110m-embedding
DATA_DIR=/path/to/data

TRG_DIR=saved_models/${LANG}/codet5p_110m_embedding_TL${TEXT_LEN}_CL${CODE_LEN}
mkdir -p $TRG_DIR
echo 'Target dir: '$TRG_DIR

python eval_contrast_retrieval.py --model_name $MODEL_NAME --lang $LANG --output_dir $TRG_DIR \
  --data_dir $DATA_DIR --max_text_len $TEXT_LEN --max_code_len $CODE_LEN --batch_size $BS
```

### Zero-shot Evaluation Results

The above script reproduces the results shown in the `CodeT5+ 110M embedding` row of the table below. We will soon release the `CodeT5+ 220M matching` model, which shares the same encoder as the embedding model; it achieves better performance than the embedding model by leveraging the fine-grained alignment between text and code through its matching decoder.
For UniXcoder's zero-shot results, we reproduce them following the official instructions [here](https://github.com/microsoft/CodeBERT/tree/master/UniXcoder/downstream-tasks/code-search#zero-shot-setting).

| Model                  | Ruby  | JavaScript | Go    | Python | Java  | PHP   | CSN_Avg | CosQA | AdvTest |
| ---------------------- | ----- | ---------- | ----- | ------ | ----- | ----- | ------- | ----- | ------- |
| UniXcoder 125M         | 57.6  | 44.2       | 64.8  | 44.7   | 46.6  | 37.3  | 49.20   | 43.1  | 29.9    |
| CodeT5+ 110M embedding | 74.51 | 69.07      | 90.69 | 71.55  | 71.82 | 67.72 | 74.23   | 39.57 | 40.49   |
| CodeT5+ 220M matching  | 75.94 | 69.85      | 91.32 | 73.97  | 74.7  | 68.28 | 75.68   | 51.54 | 42.03   |

* Note that the zero-shot results of CodeT5+ reported here differ from those in the paper, which are task-specific fine-tuned results.

# Citation

```bibtex

CodeT5+/code_retrieval/data_utils.py (new file, 261 lines)
@@ -0,0 +1,261 @@
import json
import torch
from torch.utils.data import DataLoader, Dataset


def create_dataset(data_dir, task):
    if task == 'AdvTest':
        train_dataset = csn_search_train(data_dir, task, 'train')
        val_dataset = advtest_search_eval_text(data_dir, task, 'valid')
        test_dataset = advtest_search_eval_text(data_dir, task, 'test')
        codebase_dataset = csn_search_eval_code(data_dir, task, 'test.jsonl')
        return train_dataset, val_dataset, test_dataset, codebase_dataset
    elif task == 'cosqa':
        train_dataset = cosqa_search_train(data_dir, task, 'cosqa-retrieval-train-19604.json')
        val_dataset = cosqa_search_eval_text(data_dir, task, 'cosqa-retrieval-dev-500.json')
        test_dataset = cosqa_search_eval_text(data_dir, task, 'cosqa-retrieval-test-500.json')
        codebase_dataset = cosqa_search_eval_code(data_dir, task)
        return train_dataset, val_dataset, test_dataset, codebase_dataset
    else:
        train_dataset = csn_search_train(data_dir, task, 'train')
        val_dataset = csn_search_eval_text(data_dir, task, 'valid')
        test_dataset = csn_search_eval_text(data_dir, task, 'test')
        codebase_dataset = csn_search_eval_code(data_dir, task, 'codebase.jsonl')
        return train_dataset, val_dataset, test_dataset, codebase_dataset


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
                                                      shuffle=shuffle)
        samplers.append(sampler)
    return samplers


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers,
                                                                    is_trains, collate_fns):
        if is_train:
            shuffle = (sampler is None)
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders


class Example(object):
    """A single training/test example."""

    def __init__(self,
                 idx,
                 text,
                 code,
                 url=None
                 ):
        self.idx = idx
        self.text = text
        self.code = code
        self.url = url


# Note: escape CodeT5 special tokens that appear verbatim in the data so the
# tokenizer does not treat them as control tokens (flagged in case this causes errors).
def replace_special_tokens(line):
    return line.replace('<pad>', '</pad>').replace('<s>', '<ss>').replace('</s>', '</ss>')


def read_search_examples(filename):
    """Read examples from filename."""
    examples = []
    with open(filename, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            line = line.strip()
            js = json.loads(line)
            if 'idx' not in js:
                js['idx'] = idx

            if 'function_tokens' in js:
                js['code_tokens'] = js['function_tokens']
            code = replace_special_tokens(' '.join(js['code_tokens']))
            nl = replace_special_tokens(' '.join(js['docstring_tokens']))
            examples.append(
                Example(
                    idx=idx,
                    text=nl,
                    code=code,
                    url=js['url']
                )
            )

    print(f'Read {len(examples)} data from {filename}')
    return examples


def read_cosqa_search_examples(filename):
    """Read examples from filename."""
    examples = []
    with open(filename, encoding="utf-8") as f:
        if "code_idx_map" in filename:
            js = json.load(f)
            for key in js:
                examples.append(
                    Example(
                        idx=js[key],
                        text="",
                        code=key,
                        url=js[key]
                    )
                )
        else:
            data = json.load(f)
            for idx, js in enumerate(data):
                code = replace_special_tokens(' '.join(js['code_tokens'].split()))
                nl = replace_special_tokens(' '.join(js['doc'].split()))
                examples.append(
                    Example(
                        idx=idx,
                        text=nl,
                        code=code,
                        url=js['retrieval_idx']
                    )
                )

    print(f'Read {len(examples)} data from {filename}')
    return examples


class csn_search_train(Dataset):
    def __init__(self, data_dir, lang, split='train'):
        self.examples = read_search_examples(f'{data_dir}/{lang}/{split}.jsonl')

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        ex = self.examples[index]
        return ex.text, ex.code, ex.idx


class csn_search_eval_text(Dataset):
    def __init__(self, data_dir, lang, split='valid'):
        self.examples = read_search_examples(f'{data_dir}/{lang}/{split}.jsonl')
        self.codebase = read_search_examples(f'{data_dir}/{lang}/codebase.jsonl')

        self.text = []
        self.code = []

        text2url = {}
        url2code = {}

        for idx, ex in enumerate(self.examples):
            self.text.append(ex.text)
            text2url[idx] = ex.url

        for idx, ex in enumerate(self.codebase):
            self.code.append(ex.code)
            url2code[ex.url] = idx

        # Map each text query index to the index of its ground-truth code in the codebase.
        self.text2code = {}

        for text_id, text in enumerate(self.text):
            self.text2code[text_id] = url2code[text2url[text_id]]

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        return self.text[index]


class advtest_search_eval_text(Dataset):
    def __init__(self, data_dir, lang, split='valid'):
        self.examples = read_search_examples(f'{data_dir}/{lang}/{split}.jsonl')

        # below is for AdvTest: text i is aligned with code i
        self.text2code = {}
        for ex in self.examples:
            self.text2code[ex.idx] = ex.idx

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return self.examples[index].text


class csn_search_eval_code(Dataset):
    def __init__(self, data_dir, lang, codebase_fn='codebase.jsonl'):
        self.code = [ex.code for ex in read_search_examples(f'{data_dir}/{lang}/{codebase_fn}')]

    def __len__(self):
        return len(self.code)

    def __getitem__(self, index):
        return self.code[index]


class cosqa_search_train(Dataset):
    def __init__(self, data_dir, lang, split='train'):
        self.examples = read_cosqa_search_examples(f'{data_dir}/{lang}/{split}')

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        ex = self.examples[index]
        return ex.text, ex.code, ex.idx


class cosqa_search_eval_text(Dataset):
    def __init__(self, data_dir, lang, split='valid'):
        self.examples = read_cosqa_search_examples(f'{data_dir}/{lang}/{split}')
        self.codebase = read_cosqa_search_examples(f'{data_dir}/{lang}/code_idx_map.txt')

        self.text = []
        self.code = []

        text2url = {}
        url2code = {}

        for idx, ex in enumerate(self.examples):
            self.text.append(ex.text)
            text2url[idx] = ex.url

        for idx, ex in enumerate(self.codebase):
            self.code.append(ex.code)
            url2code[ex.url] = idx

        self.text2code = {}

        for text_id, text in enumerate(self.text):
            self.text2code[text_id] = url2code[text2url[text_id]]

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        return self.text[index]


class cosqa_search_eval_code(Dataset):
    def __init__(self, data_dir, lang):
        self.code = [ex.code for ex in read_cosqa_search_examples(f'{data_dir}/{lang}/code_idx_map.txt')]

    def __len__(self):
        return len(self.code)

    def __getitem__(self, index):
        return self.code[index]

CodeT5+/code_retrieval/eval_contrast_retrieval.py (new file, 116 lines)
@@ -0,0 +1,116 @@
'''
 * Copyright (c) 2023, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Yue Wang
'''
import argparse
import os
import pprint
import json
import time
import datetime
import numpy as np
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
from data_utils import create_dataset, create_loader


@torch.no_grad()
def get_feats(model, tokenizer, data_loader, max_length, device, desc='Get feats'):
    embeds = []

    for text in tqdm(data_loader, total=len(data_loader), desc=desc):
        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=max_length,
                               return_tensors="pt").to(device)
        embed = model(text_input.input_ids, attention_mask=text_input.attention_mask)
        embeds.append(embed)

    embeds = torch.cat(embeds, dim=0)

    return embeds


@torch.no_grad()
def contrast_evaluation(text_embeds, code_embeds, img2txt):
    # `img2txt` maps each text query index to the index of its ground-truth code
    # in the codebase (the name is inherited from image-text retrieval code).
    score_matrix_i2t = text_embeds @ code_embeds.t()
    scores_i2t = score_matrix_i2t.cpu().numpy()

    ranks = np.ones(scores_i2t.shape[0]) * -1
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == img2txt[index])[0][0]

    # Compute metrics (ranks are 0-based, so rank 0 counts towards R@1)
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    mrr = 100.0 * np.mean(1 / (ranks + 1))

    eval_result = {'r1': tr1,
                   'r5': tr5,
                   'r10': tr10,
                   'mrr': mrr}
    return eval_result


def main(args):
    print("\nCreating retrieval dataset")
    _, _, test_dataset, code_dataset = create_dataset(args.data_dir, args.lang)

    test_loader, code_loader = create_loader([test_dataset, code_dataset], [None, None],
                                             batch_size=[args.batch_size, args.batch_size],
                                             num_workers=[4, 4], is_trains=[False, False], collate_fns=[None, None])

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(args.model_name, trust_remote_code=True)
    print(f'Loaded {args.model_name} model (#para={model.num_parameters()})')

    print('\nStart zero-shot evaluation...')
    device = torch.device(args.device)
    model = model.to(device)
    model.eval()

    text_embeds = get_feats(model, tokenizer, test_loader, args.max_text_len, device, desc='Get text feats')
    code_embeds = get_feats(model, tokenizer, code_loader, args.max_code_len, device, desc='Get code feats')
    test_result = contrast_evaluation(text_embeds, code_embeds, test_loader.dataset.text2code)
    print('\n====> zero-shot test result: ', test_result)

    if args.local_rank in [-1, 0]:
        log_stats = {
            **{f'test_{k}': v for k, v in test_result.items()},
            'epoch': -1,
        }

        with open(os.path.join(args.output_dir, "result.txt"), "a") as f:
            f.write(json.dumps(log_stats) + "\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--lang', type=str,
                        choices=['ruby', 'javascript', 'go', 'python', 'java', 'php', 'AdvTest', 'cosqa'])
    parser.add_argument('--model_name', type=str, default='Salesforce/codet5p-110m-embedding')
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--output_dir', type=str)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--max_text_len', default=64, type=int)
    parser.add_argument('--max_code_len', default=360, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--local_rank', default=-1, type=int)

    args = parser.parse_args()

    argsdict = vars(args)
    if args.local_rank in [0, -1]:
        print(pprint.pformat(argsdict))

    os.makedirs(args.output_dir, exist_ok=True)

    with open(os.path.join(args.output_dir, "command.txt"), 'w') as f:
        f.write(pprint.pformat(argsdict))

    main(args)

CodeT5+/code_retrieval/run_retrieval.sh (new file, 16 lines)
@@ -0,0 +1,16 @@
export TOKENIZERS_PARALLELISM=false
# choices: ruby javascript go python java php AdvTest cosqa
LANG=ruby
BS=256
CODE_LEN=360
TEXT_LEN=64
MODEL_NAME=Salesforce/codet5p-110m-embedding
DATA_DIR=/path/to/data

TRG_DIR=saved_models/${LANG}/codet5p_110m_embedding_TL${TEXT_LEN}_CL${CODE_LEN}
mkdir -p $TRG_DIR
echo 'Target dir: '$TRG_DIR

python eval_contrast_retrieval.py --model_name $MODEL_NAME --lang $LANG --output_dir $TRG_DIR \
  --data_dir $DATA_DIR --max_text_len $TEXT_LEN --max_code_len $CODE_LEN --batch_size $BS \
  2>&1 | tee ${TRG_DIR}/log.txt