diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 92de102c..00da92ed 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -123,12 +123,12 @@ class LlamacppHF(PreTrainedModel): # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee if labels is None: if past_seq is not None: - longest_prefix = 0 - for i in range(min(past_seq.shape[0], seq_tensor.shape[0])): - if past_seq[i] == seq_tensor[i]: - longest_prefix += 1 - else: - break + min_length = min(past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) + if len(indices) > 0: + longest_prefix = indices[0].item() + else: + longest_prefix = min_length if longest_prefix > 0: self.model.n_tokens = longest_prefix