Fix StreamingLLM when content is removed from the beginning of the prompt

This commit is contained in:
oobabooga 2024-03-14 09:18:54 -07:00
parent d828844a6f
commit d890c99b53

View File

@ -19,12 +19,12 @@ def process_llamacpp_cache(model, new_sequence, past_sequence):
past_sequence = torch.tensor(past_sequence) past_sequence = torch.tensor(past_sequence)
prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1]) prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1])
sink_length = prefix_length sink_length = max(prefix_length, shared.args.attention_sink_size)
if sink_length < shared.args.attention_sink_size:
sink_length = shared.args.attention_sink_size
removed_length = i1 - sink_length removed_length = i1 - sink_length
if removed_length <= 0:
return past_sequence.tolist()
matching_prefix = past_sequence[:prefix_length] matching_prefix = past_sequence[:prefix_length]
removed_chunk = past_sequence[sink_length:i1] removed_chunk = past_sequence[sink_length:i1]
overlapping_sequence = new_sequence[j1:j2 + 1] overlapping_sequence = new_sequence[j1:j2 + 1]
@ -37,10 +37,11 @@ def process_llamacpp_cache(model, new_sequence, past_sequence):
print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix))) print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix)))
print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk))) print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk)))
print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk))) print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk)))
print('REMOVED LENGTH=', removed_length)
print() print()
# Remove interval [sink_length, sink_length + removed_length) from the context # Remove interval [sink_length, sink_length + removed_length) from the context
# Subtract removed_length from model.n_tokens # Update model.n_tokens
model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length) model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length)
model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length) model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length)