From cb26163a209d6272ed14da83782f71bae4681d75 Mon Sep 17 00:00:00 2001
From: tdrussell <6509934+tdrussell@users.noreply.github.com>
Date: Thu, 5 Oct 2023 10:20:56 -0500
Subject: [PATCH] Fix off-by-one error in exllama_hf caching logic (#4145)

---
 modules/exllama_hf.py   | 4 ++++
 modules/exllamav2_hf.py | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py
index 3245ac87..3ba1f3c3 100644
--- a/modules/exllama_hf.py
+++ b/modules/exllama_hf.py
@@ -94,6 +94,10 @@ class ExllamaHF(PreTrainedModel):
                     ex_cache.current_seq_len = longest_prefix
                     if len(seq_tensor) - longest_prefix > 1:
                         self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
+                    elif len(seq_tensor) == longest_prefix:
+                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
+                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
+                        ex_cache.current_seq_len -= 1
 
             if reset:
                 ex_cache.current_seq_len = 0
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 6542ede9..71cf513f 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -98,6 +98,10 @@ class Exllamav2HF(PreTrainedModel):
                     ex_cache.current_seq_len = longest_prefix
                     if len(seq_tensor) - longest_prefix > 1:
                         self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
+                    elif len(seq_tensor) == longest_prefix:
+                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
+                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
+                        ex_cache.current_seq_len -= 1
 
             if reset:
                 ex_cache.current_seq_len = 0
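
Note: the added `elif` branch covers the case where the cached prefix already contains the entire prompt. Since the caller always feeds input_ids[-1] to forward() afterwards, the cache pointer must be backed up by one so that last token is not counted twice. The following is a minimal, self-contained sketch of that bookkeeping, not the project's code: `Cache`, `cached_tokens`, and `advance_cache` are hypothetical names standing in for `ExLlamaCache.current_seq_len` and the logic in the patched `__call__` methods.

# Hypothetical illustration of the off-by-one fix above; names are invented.
class Cache:
    def __init__(self):
        self.current_seq_len = 0  # number of tokens already held in the cache


def advance_cache(cache, cached_tokens, input_ids):
    """Position the cache pointer before the final forward() call,
    which always consumes input_ids[-1]."""
    # Length of the longest shared prefix between the cache and the new input.
    longest_prefix = 0
    for a, b in zip(cached_tokens, input_ids):
        if a != b:
            break
        longest_prefix += 1

    cache.current_seq_len = longest_prefix
    if len(input_ids) - longest_prefix > 1:
        # Several genuinely new tokens: pre-fill everything except the last
        # one, which the caller feeds to forward() itself.
        cache.current_seq_len = len(input_ids) - 1
    elif len(input_ids) == longest_prefix:
        # The reused prefix *is* the whole input: the last token is already
        # cached, so back the pointer up by one before it is fed again.
        cache.current_seq_len -= 1
    return cache.current_seq_len


if __name__ == "__main__":
    cached = [1, 2, 3, 4]
    # Whole input seen before: pointer must end at 3, not 4.
    assert advance_cache(Cache(), cached, [1, 2, 3, 4]) == 3
    # Exactly one new token: pointer stays at the shared prefix length.
    assert advance_cache(Cache(), cached, [1, 2, 3, 4, 5]) == 4
    # Several new tokens: pointer ends just before the final token.
    assert advance_cache(Cache(), cached, [1, 2, 3, 4, 5, 6]) == 5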