From d890c99b53343e2f5f08407b2b160fe44e094917 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 14 Mar 2024 09:18:54 -0700
Subject: [PATCH] Fix StreamingLLM when content is removed from the beginning
 of the prompt

---
 modules/cache_utils.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/modules/cache_utils.py b/modules/cache_utils.py
index c235d0ca..0d1368a2 100644
--- a/modules/cache_utils.py
+++ b/modules/cache_utils.py
@@ -19,12 +19,12 @@ def process_llamacpp_cache(model, new_sequence, past_sequence):
         past_sequence = torch.tensor(past_sequence)
 
         prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1])
-        sink_length = prefix_length
-        if sink_length < shared.args.attention_sink_size:
-            sink_length = shared.args.attention_sink_size
-
+        sink_length = max(prefix_length, shared.args.attention_sink_size)
         removed_length = i1 - sink_length
 
+        if removed_length <= 0:
+            return past_sequence.tolist()
+
         matching_prefix = past_sequence[:prefix_length]
         removed_chunk = past_sequence[sink_length:i1]
         overlapping_sequence = new_sequence[j1:j2 + 1]
@@ -37,10 +37,11 @@ def process_llamacpp_cache(model, new_sequence, past_sequence):
         print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix)))
         print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk)))
         print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk)))
+        print('REMOVED LENGTH=', removed_length)
         print()
 
         # Remove interval [sink_length, sink_length + removed_length) from the context
-        # Subtract removed_length from model.n_tokens
+        # Update model.n_tokens
         model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length)
         model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length)
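
For reference, a minimal standalone sketch of the length arithmetic this patch guards: keep at least attention_sink_size tokens as the attention sink, and only evict a KV-cache interval when there is actually something to remove. The function name and the toy numbers below are illustrative assumptions, not code from the repository; the real code reads attention_sink_size from shared.args and calls kv_cache_seq_rm / kv_cache_seq_shift on the llama.cpp context instead of returning a tuple.

def plan_cache_eviction(i1, prefix_length, attention_sink_size):
    """Hypothetical helper mirroring the patched computation in
    process_llamacpp_cache. i1 is where the overlap starts in the old
    cache and prefix_length is the length of the matching prefix."""
    sink_length = max(prefix_length, attention_sink_size)
    removed_length = i1 - sink_length

    if removed_length <= 0:
        # Nothing sits between the sink and the start of the overlap.
        # Calling kv_cache_seq_rm / kv_cache_seq_shift here would pass a
        # non-positive length, which is what the early return prevents.
        return None

    # The interval [sink_length, sink_length + removed_length) would be evicted.
    return sink_length, sink_length + removed_length


# Toy values (assumptions for illustration only):
print(plan_cache_eviction(i1=6, prefix_length=2, attention_sink_size=4))  # (4, 6): evict two tokens
print(plan_cache_eviction(i1=3, prefix_length=2, attention_sink_size=4))  # None: early return, cache untouched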