Faster llamacpp_HF prefix matching

This commit is contained in:
oobabooga 2023-09-18 11:02:45 -07:00
parent 893a72a1c5
commit 745807dc03

View File

@@ -123,12 +123,12 @@ class LlamacppHF(PreTrainedModel):
# https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
if labels is None:
if past_seq is not None:
longest_prefix = 0
for i in range(min(past_seq.shape[0], seq_tensor.shape[0])):
if past_seq[i] == seq_tensor[i]:
longest_prefix += 1
else:
break
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
if len(indices) > 0:
longest_prefix = indices[0].item()
else:
longest_prefix = min_length
if longest_prefix > 0:
self.model.n_tokens = longest_prefix