Optimize llamacpp_hf a bit

This commit is contained in:
oobabooga 2023-07-16 20:49:48 -07:00
parent 6a3edb0542
commit a199f21799

View File

@@ -42,7 +42,6 @@ class LlamacppHF(PreTrainedModel):
# Make the forward call
seq_tensor = torch.tensor(seq)
self.cache = seq_tensor
if labels is None:
if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]):
self.model.reset()
@@ -50,13 +49,15 @@ class LlamacppHF(PreTrainedModel):
else:
self.model.eval([seq[-1]])
logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device)
logits = torch.tensor(self.model.eval_logits[-1]).view(1, 1, -1).to(kwargs['input_ids'].device)
else:
self.model.reset()
self.model.eval(seq)
logits = torch.tensor(self.model.eval_logits)
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
self.cache = seq_tensor
# Based on transformers/models/llama/modeling_llama.py
loss = None
if labels is not None: