mirror of
https://github.com/salesforce/CodeT5.git
synced 2024-10-01 06:35:38 -04:00
add humaneval evaluation
This commit is contained in:
parent
aaeb477b89
commit
ce36447d85
@ -1,11 +1,13 @@
|
|||||||
# CodeT5+
|
# CodeT5+
|
||||||
|
|
||||||
Official research release for the **CodeT5+** models (`220M`, `770M`, `2B`, `6B` `16B`) for a wide range of **Code Understanding and Generation** tasks.
|
Official research release for the **CodeT5+** models (`220M`, `770M`, `2B`, `6B` `16B`) for a wide range of **Code Understanding and Generation** tasks.
|
||||||
|
Find out more via our [blog post](https://blog.salesforceairesearch.com/codet5-open-code-large-language-models/).
|
||||||
|
|
||||||
*Title*: [CodeT5+: Open Code Large Language Models for Code Understanding and Generation](https://arxiv.org/pdf/2305.07922.pdf)
|
*Title*: [CodeT5+: Open Code Large Language Models for Code Understanding and Generation](https://arxiv.org/pdf/2305.07922.pdf)
|
||||||
|
|
||||||
*Authors*: [Yue Wang](https://yuewang-cuhk.github.io/)\*, [Hung Le](https://sites.google.com/view/henryle2018/home?pli=1)\*, [Akhilesh Deepak Gotmare](https://akhileshgotmare.github.io/), [Nghi D.Q. Bui](https://bdqnghi.github.io/), [Junnan Li](https://sites.google.com/site/junnanlics), [Steven C.H. Hoi](https://sites.google.com/view/stevenhoi/home) (* indicates equal contribution)
|
*Authors*: [Yue Wang](https://yuewang-cuhk.github.io/)\*, [Hung Le](https://sites.google.com/view/henryle2018/home?pli=1)\*, [Akhilesh Deepak Gotmare](https://akhileshgotmare.github.io/), [Nghi D.Q. Bui](https://bdqnghi.github.io/), [Junnan Li](https://sites.google.com/site/junnanlics), [Steven C.H. Hoi](https://sites.google.com/view/stevenhoi/home) (* indicates equal contribution)
|
||||||
|
|
||||||
|
|
||||||
# What is this about?
|
# What is this about?
|
||||||
CodeT5+ is a new family of open code large language models with an encoder-decoder architecture that can flexibly operate in different modes (i.e. _encoder-only_, _decoder-only_, and _encoder-decoder_) to support a wide range of code understanding and generation tasks.
|
CodeT5+ is a new family of open code large language models with an encoder-decoder architecture that can flexibly operate in different modes (i.e. _encoder-only_, _decoder-only_, and _encoder-decoder_) to support a wide range of code understanding and generation tasks.
|
||||||
|
|
||||||
@ -32,7 +34,8 @@ We release the following CodeT5+ models at Huggingface:
|
|||||||
# How to Use?
|
# How to Use?
|
||||||
All CodeT5+ models and tokenizers can be easily loaded using the `AutoModelForSeq2SeqLM` and `AutoTokenizer` functionality.
|
All CodeT5+ models and tokenizers can be easily loaded using the `AutoModelForSeq2SeqLM` and `AutoTokenizer` functionality.
|
||||||
For tokenizers, CodeT5+ `220M` and `770M` employ the same tokenizer as the original [CodeT5](https://github.com/salesforce/CodeT5) while CodeT5+ `2B`, `6B`, `16B` employ the same tokenizer as [CodeGen]( https://github.com/salesforce/CodeGen).
|
For tokenizers, CodeT5+ `220M` and `770M` employ the same tokenizer as the original [CodeT5](https://github.com/salesforce/CodeT5) while CodeT5+ `2B`, `6B`, `16B` employ the same tokenizer as [CodeGen]( https://github.com/salesforce/CodeGen).
|
||||||
To load CodeT5+ `2B`, `6B`, `16B`, please set `trust_remote_code=True` as the [model class](https://huggingface.co/Salesforce/codet5p-16b/blob/main/modeling_codet5p.py) is defined in the Huggingface repo.
|
To load CodeT5+ `2B`, `6B`, `16B`, and InstructCodeT5+ `16B`, please set `trust_remote_code=True` as the [model class](https://huggingface.co/Salesforce/codet5p-16b/blob/main/modeling_codet5p.py) is defined in the Huggingface repo.
|
||||||
|
Besides, these models would benefit from passing additional prompts to the decoder via `decoder_input_ids` to achieve better generation performance.
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@ -48,17 +51,28 @@ model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
|
|||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
trust_remote_code=True).to(device)
|
trust_remote_code=True).to(device)
|
||||||
|
|
||||||
inputs = tokenizer.encode("def print_hello():", return_tensors="pt").to(device)
|
encoding = tokenizer("def print_hello_world():", return_tensors="pt").to(device)
|
||||||
outputs = model.generate(inputs, max_length=12)
|
encoding['decoder_input_ids'] = encoding['input_ids'].clone()
|
||||||
|
outputs = model.generate(**encoding, max_length=15)
|
||||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
# Reproduce the Results
|
# Reproduce the Results
|
||||||
|
|
||||||
## HumanEval
|
## HumanEval
|
||||||
|
|
||||||
TBA
|
### Installation
|
||||||
|
* Install the official HumanEval evaluation tool released by OpenAI following the instructions in ihis [repo](https://github.com/openai/human-eval).
|
||||||
|
* Install the Pytorch (version `1.13.1`) and transformers (version `4.21.3`) libraries.
|
||||||
|
|
||||||
|
### Generating programs from CodeT5+ models
|
||||||
|
`cd humaneval` then run the inference via `bash run_generate.sh`.
|
||||||
|
You can select the model to generate from by changing the `model` variable in the script.
|
||||||
|
Following the original setting in the HumanEval paper, we generate 200 programs (`pred_num=200`) for each problem and employs nucleus sampling with different temperature `T` for computing `pass@k` (`T=0.2,0.6,0.8` for `k=1,10,100` respectively).
|
||||||
|
The generated programs will be saved in `preds/${model}_T${temp}_N${pred_num}`.
|
||||||
|
|
||||||
|
### Evaluating pass@k
|
||||||
|
`cd humaneval` then run the evaluation via `bash run_eval.sh`.
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
162
CodeT5+/humaneval/generate_codet5p.py
Normal file
162
CodeT5+/humaneval/generate_codet5p.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import argparse
|
||||||
|
import pprint
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
from human_eval.data import write_jsonl, read_problems, stream_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text(prompt, remove_lines=True):
|
||||||
|
token = '\"\"\"'
|
||||||
|
start = token
|
||||||
|
end = '>>>'
|
||||||
|
|
||||||
|
start_idx = prompt.find(start) + len(start)
|
||||||
|
end_idx = prompt.find(end)
|
||||||
|
|
||||||
|
output = prompt[start_idx: end_idx]
|
||||||
|
if remove_lines:
|
||||||
|
output = output.replace('\n', ' ')
|
||||||
|
output = re.sub(r"\s+", " ", output).strip()
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
INSTRUCTION = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||||
|
|
||||||
|
|
||||||
|
### Instruction:
|
||||||
|
Create a Python script for this problem:
|
||||||
|
{}
|
||||||
|
|
||||||
|
### Response:"""
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument('--model', type=str, default='Salesforce/instructcodet5p-16b', help="")
|
||||||
|
parser.add_argument('--output_path', type=str, help="")
|
||||||
|
parser.add_argument('--start_index', type=int, default=0, help="")
|
||||||
|
parser.add_argument('--end_index', type=int, default=164, help="")
|
||||||
|
parser.add_argument('--temperature', type=float, default=0.8, help="")
|
||||||
|
parser.add_argument('--N', type=int, default=200, help="")
|
||||||
|
parser.add_argument('--max_len', type=int, default=600, help="")
|
||||||
|
parser.add_argument('--decoding_style', type=str, default='sampling', help="")
|
||||||
|
parser.add_argument('--num_seqs_per_iter', type=int, default=50, help='')
|
||||||
|
parser.add_argument('--overwrite', action='store_true', help='')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
argsdict = vars(args)
|
||||||
|
print(pprint.pformat(argsdict))
|
||||||
|
|
||||||
|
STOP_SEQS = ['\nclass', '\ndef', '\n#', '\nif', '\nprint']
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
problems = read_problems()
|
||||||
|
|
||||||
|
task_ids = sorted(problems.keys())[args.start_index: args.end_index]
|
||||||
|
prompts = [problems[task_id]['prompt'] for task_id in task_ids]
|
||||||
|
num_samples = len(prompts)
|
||||||
|
print("Number of samples: {}".format(num_samples))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||||
|
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(args.model,
|
||||||
|
trust_remote_code=True, # False for 220m and 770m models
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
low_cpu_mem_usage=True)
|
||||||
|
model.eval()
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# for larger LLMs such as 2B, 6B, and 16B, we need to pass the text prompt to the decoder
|
||||||
|
prompt_to_decoder = True if any([size in args.model for size in ['2b', '6b', '16b']]) else False
|
||||||
|
|
||||||
|
print(f"Loaded {args.model}.")
|
||||||
|
for i in tqdm(range(num_samples), ncols=0, total=num_samples):
|
||||||
|
output_file = args.output_path + '/{}.jsonl'.format(args.start_index + i)
|
||||||
|
|
||||||
|
if os.path.exists(output_file) and not args.overwrite:
|
||||||
|
print(f'Skip {output_file} as it already exists')
|
||||||
|
continue
|
||||||
|
|
||||||
|
prompt = prompts[i].replace(' ', '\t')
|
||||||
|
if args.model == 'Salesforce/instructcodet5p-16b':
|
||||||
|
prompt_batch = [INSTRUCTION.format(extract_text(prompt))]
|
||||||
|
prompt_batch_decoder = [INSTRUCTION.format(extract_text(prompt)) + prompt]
|
||||||
|
else:
|
||||||
|
prompt_batch = [prompt]
|
||||||
|
prompt_batch_decoder = [prompt]
|
||||||
|
|
||||||
|
ids_batch = [task_ids[i]]
|
||||||
|
|
||||||
|
completion_seqs = []
|
||||||
|
|
||||||
|
encoding = tokenizer(prompt_batch, return_tensors="pt", truncation=True, max_length=args.max_len).to(device)
|
||||||
|
encoding_decoder = tokenizer(prompt_batch_decoder, return_tensors="pt", truncation=True,
|
||||||
|
max_length=args.max_len).to(device)
|
||||||
|
|
||||||
|
if args.decoding_style == 'sampling':
|
||||||
|
loops = int(args.N / args.num_seqs_per_iter)
|
||||||
|
else:
|
||||||
|
loops = 1
|
||||||
|
|
||||||
|
for _ in tqdm(range(loops), total=loops, leave=False, ncols=0):
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if args.decoding_style == 'sampling':
|
||||||
|
if prompt_to_decoder:
|
||||||
|
gen_tokens = model.generate(**encoding,
|
||||||
|
decoder_input_ids=encoding_decoder['input_ids'],
|
||||||
|
do_sample=True,
|
||||||
|
temperature=args.temperature,
|
||||||
|
max_length=args.max_len,
|
||||||
|
num_return_sequences=args.num_seqs_per_iter,
|
||||||
|
decoder_start_token_id=tokenizer.pad_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
top_p=0.95)
|
||||||
|
else:
|
||||||
|
gen_tokens = model.generate(**encoding,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=args.temperature,
|
||||||
|
max_length=args.max_len,
|
||||||
|
num_return_sequences=args.num_seqs_per_iter,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
top_p=0.95)
|
||||||
|
|
||||||
|
if gen_tokens is not None:
|
||||||
|
if prompt_to_decoder:
|
||||||
|
gen_tokens = gen_tokens[:, encoding_decoder['input_ids'].shape[-1]:]
|
||||||
|
gen_seqs = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
|
||||||
|
else:
|
||||||
|
gen_seqs = None
|
||||||
|
|
||||||
|
if gen_seqs is not None:
|
||||||
|
assert len(ids_batch) == 1
|
||||||
|
task_id = ids_batch[0]
|
||||||
|
|
||||||
|
for seq_idx, gen_seq in enumerate(gen_seqs):
|
||||||
|
completion_seq = gen_seq
|
||||||
|
for stop_seq in STOP_SEQS:
|
||||||
|
index = completion_seq.find(stop_seq)
|
||||||
|
if index != -1:
|
||||||
|
completion_seq = completion_seq[:index]
|
||||||
|
completion_seq = completion_seq.replace('\t', ' ')
|
||||||
|
all_code = prompt.replace('\t', ' ') + completion_seq
|
||||||
|
|
||||||
|
completion_seqs.append(
|
||||||
|
{'task_id': task_id,
|
||||||
|
'completion': completion_seq,
|
||||||
|
'all_code': all_code # final code for evaluation with unit tests
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Saving results to {}".format(output_file))
|
||||||
|
write_jsonl(output_file, completion_seqs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
47
CodeT5+/humaneval/process_preds.py
Normal file
47
CodeT5+/humaneval/process_preds.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from human_eval.data import read_problems, write_jsonl, stream_jsonl
|
||||||
|
import glob
|
||||||
|
from tqdm import tqdm
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
parser.add_argument(
|
||||||
|
'--path',
|
||||||
|
type=str,
|
||||||
|
help="")
|
||||||
|
parser.add_argument(
|
||||||
|
'--out_path',
|
||||||
|
type=str,
|
||||||
|
help="")
|
||||||
|
parser.add_argument(
|
||||||
|
'--add_prompt',
|
||||||
|
action='store_true',
|
||||||
|
help='')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
files = sorted(glob.glob(args.path + '/*.jsonl'))
|
||||||
|
print("{} files in {}".format(len(files), args.path))
|
||||||
|
|
||||||
|
problems = read_problems('data/HumanEval.jsonl.gz')
|
||||||
|
|
||||||
|
output = []
|
||||||
|
for code_file in tqdm(files, total=len(files)):
|
||||||
|
codes = [c for c in stream_jsonl(code_file)]
|
||||||
|
if args.add_prompt:
|
||||||
|
for code in codes:
|
||||||
|
task_id = code['task_id']
|
||||||
|
prompt = problems[task_id]['prompt']
|
||||||
|
if 'def' in code['completion']:
|
||||||
|
def_line = code['completion'].index('def')
|
||||||
|
completion = code['completion'][def_line:]
|
||||||
|
next_line = completion.index('\n')
|
||||||
|
completion = code['completion'][def_line+next_line+1:]
|
||||||
|
code['all_code'] = prompt + completion
|
||||||
|
|
||||||
|
output += codes
|
||||||
|
|
||||||
|
print("save to {}".format(args.out_path))
|
||||||
|
write_jsonl(args.out_path, output)
|
29
CodeT5+/humaneval/run_generate.sh
Normal file
29
CodeT5+/humaneval/run_generate.sh
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
model=instructcodet5p-16b
|
||||||
|
temp=0.2
|
||||||
|
max_len=800
|
||||||
|
pred_num=200
|
||||||
|
num_seqs_per_iter=2 # 25 for 350M and 770M, 10 for 2B, 8 for 6B, 2 for 16B on A100-40G
|
||||||
|
|
||||||
|
output_path=preds/${model}_T${temp}_N${pred_num}
|
||||||
|
|
||||||
|
mkdir -p ${output_path}
|
||||||
|
echo 'Output path: '$output_path
|
||||||
|
echo 'Model to eval: '$model
|
||||||
|
|
||||||
|
# 164 problems, 21 per GPU if GPU=8
|
||||||
|
index=0
|
||||||
|
gpu_num=8
|
||||||
|
for ((i = 0; i < $gpu_num; i++)); do
|
||||||
|
start_index=$((i * 21))
|
||||||
|
end_index=$(((i + 1) * 21))
|
||||||
|
|
||||||
|
gpu=$((i))
|
||||||
|
echo 'Running process #' ${i} 'from' $start_index 'to' $end_index 'on GPU' ${gpu}
|
||||||
|
((index++))
|
||||||
|
(
|
||||||
|
CUDA_VISIBLE_DEVICES=$gpu python generate_codet5p.py --model Salesforce/${model} \
|
||||||
|
--start_index ${start_index} --end_index ${end_index} --temperature ${temp} \
|
||||||
|
--num_seqs_per_iter ${num_seqs_per_iter} --N ${pred_num} --max_len ${max_len} --output_path ${output_path}
|
||||||
|
) &
|
||||||
|
if (($index % $gpu_num == 0)); then wait; fi
|
||||||
|
done
|
6
CodeT5+/humaneval/test_eval.sh
Normal file
6
CodeT5+/humaneval/test_eval.sh
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
output_path=preds/instructcodet5p-16b_T0.2_N200
|
||||||
|
|
||||||
|
echo 'Output path: '$output_path
|
||||||
|
python process_preds.py --path ${output_path} --out_path ${output_path}.jsonl
|
||||||
|
|
||||||
|
evaluate_functional_correctness ${output_path}.jsonl
|
@ -197,7 +197,23 @@ Note that we employ one A100 GPU for all fine-tuning experiments.
|
|||||||
### How to fine-tune on your own task and dataset?
|
### How to fine-tune on your own task and dataset?
|
||||||
If you want to fine-tune on your dataset, you can add your own task and sub_task in `configs.py` ([here](https://github.com/salesforce/CodeT5/blob/d27512d23ba6130e089e571d8c3e399760db1c31/configs.py#L11)) and add your data path and the function to read in `utils.py` ([here](https://github.com/salesforce/CodeT5/blob/5bb41e21b07fee73f310476a91ded00e385290d7/utils.py#L103) and [here](https://github.com/salesforce/CodeT5/blob/5bb41e21b07fee73f310476a91ded00e385290d7/utils.py#L149)). The read function can be implemented in `_utils.py` similar to [this one](https://github.com/salesforce/CodeT5/blob/aaf9c4a920c4986abfd54a74f5456b056b6409e0/_utils.py#L213). If your task to add is a generation task, you can simply reuse or customize the `run_gen.py`. For understanding tasks, please refer to `run_defect.py` and `run_clone.py`.
|
If you want to fine-tune on your dataset, you can add your own task and sub_task in `configs.py` ([here](https://github.com/salesforce/CodeT5/blob/d27512d23ba6130e089e571d8c3e399760db1c31/configs.py#L11)) and add your data path and the function to read in `utils.py` ([here](https://github.com/salesforce/CodeT5/blob/5bb41e21b07fee73f310476a91ded00e385290d7/utils.py#L103) and [here](https://github.com/salesforce/CodeT5/blob/5bb41e21b07fee73f310476a91ded00e385290d7/utils.py#L149)). The read function can be implemented in `_utils.py` similar to [this one](https://github.com/salesforce/CodeT5/blob/aaf9c4a920c4986abfd54a74f5456b056b6409e0/_utils.py#L213). If your task to add is a generation task, you can simply reuse or customize the `run_gen.py`. For understanding tasks, please refer to `run_defect.py` and `run_clone.py`.
|
||||||
|
|
||||||
## Get Involved
|
|
||||||
|
|
||||||
Please create a GitHub issue if you have any questions, suggestions, requests or bug-reports. We welcome PRs!
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{
|
||||||
|
wang2021codet5,
|
||||||
|
title={CodeT5: Identifier-aware Unified Pre-trained Encoder-Decoder Models for Code Understanding and Generation},
|
||||||
|
author={Yue Wang, Weishi Wang, Shafiq Joty, Steven C.H. Hoi},
|
||||||
|
booktitle={EMNLP},
|
||||||
|
year={2021},
|
||||||
|
}
|
||||||
|
|
||||||
|
@inproceedings{
|
||||||
|
le2022coderl,
|
||||||
|
title={CodeRL: Mastering Code Generation through Pretrained Models and Deep Reinforcement Learning},
|
||||||
|
author={Le, Hung and Wang, Yue and Gotmare, Akhilesh Deepak and Savarese, Silvio and Hoi, Steven C. H.},
|
||||||
|
booktitle={NeurIPS},
|
||||||
|
year={2022}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
@ -26,7 +26,7 @@ At Salesforce, we build an AI coding assistant demo using CodeT5 as a VS Code pl
|
|||||||
**May 2023**
|
**May 2023**
|
||||||
|
|
||||||
**CodeT5+** paper and models are released!🔥 <br>
|
**CodeT5+** paper and models are released!🔥 <br>
|
||||||
[paper](https://arxiv.org/pdf/2305.07922.pdf) | [code](https://github.com/salesforce/CodeT5/tree/main/CodeT5+) | [model](https://huggingface.co/models?sort=downloads&search=codet5p)
|
[paper](https://arxiv.org/pdf/2305.07922.pdf) | [code](https://github.com/salesforce/CodeT5/tree/main/CodeT5+) | [model](https://huggingface.co/models?sort=downloads&search=codet5p) | [blog](https://blog.salesforceairesearch.com/codet5-open-code-large-language-models/)
|
||||||
|
|
||||||
**Sep 2022**
|
**Sep 2022**
|
||||||
|
|
||||||
@ -56,7 +56,6 @@ multilingual code summarization.
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
If you find this code to be useful for your research, please consider citing:
|
If you find this code to be useful for your research, please consider citing:
|
||||||
@ -74,7 +73,7 @@ If you find this code to be useful for your research, please consider citing:
|
|||||||
le2022coderl,
|
le2022coderl,
|
||||||
title={CodeRL: Mastering Code Generation through Pretrained Models and Deep Reinforcement Learning},
|
title={CodeRL: Mastering Code Generation through Pretrained Models and Deep Reinforcement Learning},
|
||||||
author={Le, Hung and Wang, Yue and Gotmare, Akhilesh Deepak and Savarese, Silvio and Hoi, Steven C. H.},
|
author={Le, Hung and Wang, Yue and Gotmare, Akhilesh Deepak and Savarese, Silvio and Hoi, Steven C. H.},
|
||||||
journal={NeurIPS},
|
booktitle={NeurIPS},
|
||||||
year={2022}
|
year={2022}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user