# CodeT5/models.py
#
# 399 lines
# 17 KiB
# Python

import torch
import torch.nn as nn
import numpy as np
from transformers import (RobertaConfig, RobertaModel, RobertaTokenizer,
BartConfig, BartForConditionalGeneration, BartTokenizer,
T5Config, T5ForConditionalGeneration, T5Tokenizer)
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Lookup table: model-type key -> (config class, model class, tokenizer class).
# Note that 'codet5' pairs the T5 config/model classes with the Roberta tokenizer.
MODEL_CLASSES = dict(
    roberta=(RobertaConfig, RobertaModel, RobertaTokenizer),
    t5=(T5Config, T5ForConditionalGeneration, T5Tokenizer),
    codet5=(T5Config, T5ForConditionalGeneration, RobertaTokenizer),
    bart=(BartConfig, BartForConditionalGeneration, BartTokenizer),
)
def get_model_size(model):
    """Return the trainable-parameter count of `model` as a string in millions.

    Only parameters whose ``requires_grad`` flag is set are counted. The total
    is divided by 1e6, rounded to the nearest integer and suffixed with ``M``
    (for example ``"223M"``).

    Parameters:
    * `model`- any object exposing ``parameters()`` that yields tensors.
    """
    # Tensor.numel() gives the element count directly, so no numpy round-trip
    # (and no numpy scalar types leaking into round()/format) is needed.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return "{}M".format(round(trainable / 1e+6))
def build_or_load_gen_model(args):
    """Create the config, model and tokenizer selected by `args`.

    Parameters:
    * `args`- namespace providing `model_type` (a key of `MODEL_CLASSES`),
      `config_name`, `model_name_or_path`, `tokenizer_name`, `load_model_path`,
      and, for the 'roberta' type, `beam_size` and `max_target_length`.

    For `model_type == 'roberta'` the pretrained encoder is combined with a
    freshly initialised 6-layer `nn.TransformerDecoder` inside a `Seq2Seq`
    wrapper; every other type is loaded directly with `from_pretrained`.
    If `args.load_model_path` is not None, the state dict stored at that path
    is loaded into the model afterwards.

    Returns: the tuple `(config, model, tokenizer)`.
    """
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name or args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name)

    if args.model_type == 'roberta':
        encoder = model_class.from_pretrained(args.model_name_or_path, config=config)
        decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size,
                                                   nhead=config.num_attention_heads)
        decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                        beam_size=args.beam_size, max_length=args.max_target_length,
                        sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
    else:
        model = model_class.from_pretrained(args.model_name_or_path)

    logger.info("Finish loading model [%s] from %s", get_model_size(model), args.model_name_or_path)

    if args.load_model_path is not None:
        logger.info("Reload model from %s", args.load_model_path)
        # Deserialize onto the CPU so checkpoints written on a GPU machine can
        # be reloaded anywhere; load_state_dict copies the values into the
        # model's existing parameters, whatever device those live on.
        state_dict = torch.load(args.load_model_path, map_location="cpu")
        model.load_state_dict(state_dict)

    return config, model, tokenizer
class RobertaClassificationHead(nn.Module):
    """Two-way classification head over pairs of feature vectors.

    The input is regrouped so that every two consecutive vectors of width
    `config.hidden_size` form one row of width `2 * hidden_size`, which is
    then passed through linear -> tanh -> linear to produce 2 logits.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(2 * hidden, hidden)
        self.out_proj = nn.Linear(hidden, 2)

    def forward(self, x, **kwargs):
        """Return logits of shape (x.numel() / (2 * x.size(-1)), 2)."""
        # Concatenate neighbouring vectors into a single feature row.
        paired = x.reshape(-1, 2 * x.size(-1))
        activated = torch.tanh(self.dense(paired))
        return self.out_proj(activated)
class CloneModel(nn.Module):
    """Two-class classifier over pairs of encoded sequences.

    `forward` reshapes `source_ids` into rows of `args.max_source_length`
    tokens, turns every row into one feature vector using the encoder that
    matches `args.model_type`, and feeds the vectors to a
    `RobertaClassificationHead`, which joins each two consecutive vectors
    into one example.

    Parameters:
    * `encoder`- the pretrained model used to embed the token ids.
    * `config`- model configuration; `hidden_size` and `eos_token_id` are read.
    * `tokenizer`- tokenizer; only `pad_token_id` is read.
    * `args`- namespace providing `model_type` and `max_source_length`.
    """

    def __init__(self, encoder, config, tokenizer, args):
        super(CloneModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config)
        self.args = args

    def _get_eos_vec(self, source_ids):
        """Return the last decoder hidden state at the final <eos> of each row.

        The encoder-decoder model is run with the source ids used both as
        input and as labels. Raises ValueError when the rows do not all hold
        the same number of <eos> tokens.
        """
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        outputs = self.encoder(input_ids=source_ids, attention_mask=attention_mask,
                               labels=source_ids, decoder_attention_mask=attention_mask,
                               output_hidden_states=True)
        hidden_states = outputs['decoder_hidden_states'][-1]
        eos_mask = source_ids.eq(self.config.eos_token_id)
        # The reshape below is only valid when every row selects the same
        # number of positions.
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        per_row = hidden_states[eos_mask, :].view(hidden_states.size(0), -1,
                                                  hidden_states.size(-1))
        return per_row[:, -1, :]

    def get_t5_vec(self, source_ids):
        """Feature vector per row for a T5-style encoder-decoder model."""
        return self._get_eos_vec(source_ids)

    def get_bart_vec(self, source_ids):
        """Feature vector per row for a BART-style encoder-decoder model."""
        return self._get_eos_vec(source_ids)

    def get_roberta_vec(self, source_ids):
        """Feature vector per row: the encoder output at position 0."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        return self.encoder(input_ids=source_ids, attention_mask=attention_mask)[0][:, 0, :]

    def forward(self, source_ids=None, labels=None):
        """Classify the examples in `source_ids`.

        Returns `(loss, prob)` when `labels` is given, otherwise `prob`, the
        softmax over the two output logits. Raises ValueError when
        `args.model_type` is not one of 'codet5', 'bart' or 'roberta'.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type == 'codet5':
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError("Unsupported model type: {}".format(model_type))

        logits = self.classifier(vec)
        # Explicit class dimension; relying on the implicit one is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)
        if labels is None:
            return prob
        loss = nn.CrossEntropyLoss()(logits, labels)
        return loss, prob
class DefectModel(nn.Module):
    """Two-class classifier over single encoded sequences.

    `forward` reshapes `source_ids` into rows of `args.max_source_length`
    tokens, turns every row into one feature vector using the encoder that
    matches `args.model_type`, and maps each vector to 2 logits with a
    linear layer.

    Parameters:
    * `encoder`- the pretrained model used to embed the token ids.
    * `config`- model configuration; `hidden_size` and `eos_token_id` are read.
    * `tokenizer`- tokenizer; only `pad_token_id` is read.
    * `args`- namespace providing `model_type` and `max_source_length`.
    """

    def __init__(self, encoder, config, tokenizer, args):
        super(DefectModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.args = args

    def _get_eos_vec(self, source_ids):
        """Return the last decoder hidden state at the final <eos> of each row.

        The encoder-decoder model is run with the source ids used both as
        input and as labels. Raises ValueError when the rows do not all hold
        the same number of <eos> tokens.
        """
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        outputs = self.encoder(input_ids=source_ids, attention_mask=attention_mask,
                               labels=source_ids, decoder_attention_mask=attention_mask,
                               output_hidden_states=True)
        hidden_states = outputs['decoder_hidden_states'][-1]
        eos_mask = source_ids.eq(self.config.eos_token_id)
        # The reshape below is only valid when every row selects the same
        # number of positions.
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        per_row = hidden_states[eos_mask, :].view(hidden_states.size(0), -1,
                                                  hidden_states.size(-1))
        return per_row[:, -1, :]

    def get_t5_vec(self, source_ids):
        """Feature vector per row for a T5-style encoder-decoder model."""
        return self._get_eos_vec(source_ids)

    def get_bart_vec(self, source_ids):
        """Feature vector per row for a BART-style encoder-decoder model."""
        return self._get_eos_vec(source_ids)

    def get_roberta_vec(self, source_ids):
        """Feature vector per row: the encoder output at position 0."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        return self.encoder(input_ids=source_ids, attention_mask=attention_mask)[0][:, 0, :]

    def forward(self, source_ids=None, labels=None):
        """Classify the examples in `source_ids`.

        Returns `(loss, prob)` when `labels` is given, otherwise `prob`, the
        softmax over the two output logits. Raises ValueError when
        `args.model_type` is not one of 'codet5', 'bart' or 'roberta'.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type == 'codet5':
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError("Unsupported model type: {}".format(model_type))

        logits = self.classifier(vec)
        # Explicit class dimension; relying on the implicit one is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)
        if labels is None:
            return prob
        loss = nn.CrossEntropyLoss()(logits, labels)
        return loss, prob
# https://github.com/microsoft/CodeBERT/blob/master/CodeBERT/code2nl/model.py
class Seq2Seq(nn.Module):
    """
    Build Sequence-to-Sequence.
    Parameters:
    * `encoder`- encoder of seq2seq model. e.g. roberta
    * `decoder`- decoder of seq2seq model. e.g. transformer
    * `config`- configuration of encoder model.
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of target for beam search.
    * `sos_id`- start of symbol ids in target for beam search.
    * `eos_id`- end of symbol ids in target for beam search.
    """

    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular matrix of ones; sliced to build causal masks.
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()
        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.encoder.embeddings.word_embeddings)

    def _causal_mask(self, length):
        """Additive (length x length) mask: 0 on/below the diagonal, -1e4 above."""
        return -1e4 * (1 - self.bias[:length, :length])

    def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None):
        """Compute the training loss, or run beam search when `target_ids` is None.

        Parameters:
        * `source_ids`- source token ids; one row per example.
        * `source_mask`- mask over `source_ids`, non-zero/True at real tokens.
          Both bool and 0/1 integer masks are accepted.
        * `target_ids`- target token ids; when None the model predicts.
        * `target_mask`- mask over `target_ids`, non-zero at real tokens.
        * `args`- unused.

        Returns: with targets, the tuple `(loss, loss * n, n)` where `n` is the
        number of target positions that contribute to the loss; without
        targets, a tensor of predicted ids with one (beam_size x max_length)
        block per example, padded with 0.
        """
        outputs = self.encoder(source_ids, attention_mask=source_mask)
        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
        # The decoder wants True at padded source positions. Casting first
        # keeps `~` a logical NOT even when the caller passes a 0/1 integer
        # mask (on integers `~` would be a bitwise NOT).
        memory_padding = ~source_mask.bool()
        if target_ids is not None:
            return self._compute_loss(encoder_output, memory_padding, target_ids, target_mask)
        return self._generate(encoder_output, memory_padding, source_ids)

    def _compute_loss(self, encoder_output, memory_padding, target_ids, target_mask):
        """Teacher-forced decoding; return `(loss, loss * n, n)`."""
        attn_mask = self._causal_mask(target_ids.shape[1])
        tgt_embeddings = self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
        out = self.decoder(tgt_embeddings, encoder_output, tgt_mask=attn_mask,
                           memory_key_padding_mask=memory_padding)
        hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
        lm_logits = self.lm_head(hidden_states)
        # Shift so that tokens < n predict n
        active_loss = target_mask[..., 1:].ne(0).view(-1)
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()
        # Flatten the tokens and keep only the unmasked target positions.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                        shift_labels.view(-1)[active_loss])
        num_active = active_loss.sum()
        return loss, loss * num_active, num_active

    def _generate(self, encoder_output, memory_padding, source_ids):
        """Beam-search every example independently; return the padded predictions."""
        device = source_ids.device
        # Padding token, created on the input's device rather than on a
        # hard-coded CUDA device.
        zero = torch.zeros(1, dtype=torch.long, device=device)
        preds = []
        for i in range(source_ids.shape[0]):
            # Replicate this example's encoder states once per beam.
            context = encoder_output[:, i:i + 1].repeat(1, self.beam_size, 1)
            context_padding = memory_padding[i:i + 1, :].repeat(self.beam_size, 1)
            beam = Beam(self.beam_size, self.sos_id, self.eos_id)
            input_ids = beam.getCurrentState().to(device)
            for _ in range(self.max_length):
                if beam.done():
                    break
                attn_mask = self._causal_mask(input_ids.shape[1])
                tgt_embeddings = self.encoder.embeddings(input_ids).permute([1, 0, 2]).contiguous()
                out = self.decoder(tgt_embeddings, context, tgt_mask=attn_mask,
                                   memory_key_padding_mask=context_padding)
                out = torch.tanh(self.dense(out))
                # Only the newest position is needed to score the next token.
                hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                # Reorder the prefixes to follow the surviving beams, then
                # append the token each beam just chose.
                origin = beam.getCurrentOrigin().to(device)
                input_ids.data.copy_(input_ids.data.index_select(0, origin))
                input_ids = torch.cat((input_ids, beam.getCurrentState().to(device)), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            pred = [torch.cat([x.view(-1).to(device) for x in p]
                              + [zero] * (self.max_length - len(p))).view(1, -1)
                    for p in pred]
            preds.append(torch.cat(pred, 0).unsqueeze(0))
        return torch.cat(preds, 0)
class Beam(object):
def __init__(self, size, sos, eos):
self.size = size
self.tt = torch.cuda
# The score for each translation on the beam.
self.scores = self.tt.FloatTensor(size).zero_()
# The backpointers at each time-step.
self.prevKs = []
# The outputs at each time-step.
self.nextYs = [self.tt.LongTensor(size)
.fill_(0)]
self.nextYs[0][0] = sos
# Has EOS topped the beam yet.
self._eos = eos
self.eosTop = False
# Time and k pair for finished.
self.finished = []
def getCurrentState(self):
"Get the outputs for the current timestep."
batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
return batch
def getCurrentOrigin(self):
"Get the backpointers for the current timestep."
return self.prevKs[-1]
def advance(self, wordLk):
"""
Given prob over words for every last beam `wordLk` and attention
`attnOut`: Compute and update the beam search.
Parameters:
* `wordLk`- probs of advancing from the last step (K x words)
* `attnOut`- attention at the last step
Returns: True if beam search is complete.
"""
numWords = wordLk.size(1)
# Sum the previous scores.
if len(self.prevKs) > 0:
beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
# Don't let EOS have children.
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
beamLk[i] = -1e20
else:
beamLk = wordLk[0]
flatBeamLk = beamLk.view(-1)
bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
self.scores = bestScores
# bestScoresId is flattened beam x word array, so calculate which
# word and beam each score came from
prevK = bestScoresId // numWords
self.prevKs.append(prevK)
self.nextYs.append((bestScoresId - prevK * numWords))
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
s = self.scores[i]
self.finished.append((s, len(self.nextYs) - 1, i))
# End condition is when top-of-beam is EOS and no global score.
if self.nextYs[-1][0] == self._eos:
self.eosTop = True
def done(self):
return self.eosTop and len(self.finished) >= self.size
def getFinal(self):
if len(self.finished) == 0:
self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
self.finished.sort(key=lambda a: -a[0])
if len(self.finished) != self.size:
unfinished = []
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] != self._eos:
s = self.scores[i]
unfinished.append((s, len(self.nextYs) - 1, i))
unfinished.sort(key=lambda a: -a[0])
self.finished += unfinished[:self.size - len(self.finished)]
return self.finished[:self.size]
def getHyp(self, beam_res):
"""
Walk back to construct the full hypothesis.
"""
hyps = []
for _, timestep, k in beam_res:
hyp = []
for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
hyp.append(self.nextYs[j + 1][k])
k = self.prevKs[j][k]
hyps.append(hyp[::-1])
return hyps
def buildTargetTokens(self, preds):
sentence = []
for pred in preds:
tokens = []
for tok in pred:
if tok == self._eos:
break
tokens.append(tok)
sentence.append(tokens)
return sentence