def process_llamacpp_cache(model, new_sequence, past_sequence, attention_sink_size=None, min_overlap_ratio=0.2):
    '''
    StreamingLLM-style trimming of the llama.cpp KV cache.

    Finds the longest run of tokens shared by `past_sequence` (what is
    currently cached) and `new_sequence` (what is about to be evaluated).
    If that run does not start at the beginning of `past_sequence` and is
    long enough, the cached tokens between the attention sink and the run
    are dropped from the cache and the remaining entries are shifted back,
    so the shared run does not have to be evaluated again.

    Parameters:
        model: object exposing `_ctx.kv_cache_seq_rm`, `_ctx.kv_cache_seq_shift`,
            `input_ids` and `n_tokens`. It is only touched when trimming happens.
        new_sequence: list of token ids for the upcoming generation.
        past_sequence: sequence of token ids currently held in the cache.
        attention_sink_size: minimum number of leading cached tokens to keep.
            Defaults to `shared.args.attention_sink_size`.
        min_overlap_ratio: the shared run must be longer than this fraction
            of `new_sequence` for trimming to happen.

    Returns:
        `new_sequence[:j2 + 1]` as a list when the cache was trimmed, where
        j2 is the last index of the shared run in `new_sequence`; otherwise
        `past_sequence`, unchanged.
    '''

    # Shortcut: with nothing cached or nothing to evaluate there is no overlap.
    if len(past_sequence) == 0 or len(new_sequence) == 0:
        return past_sequence

    i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
    overlap_length = i2 - i1 + 1

    # Do StreamingLLM only if i1 > 0 (ie the longest common substring is not
    # a prefix of the cache) and the overlap length is sufficiently long.
    if i1 <= 0 or overlap_length <= min_overlap_ratio * len(new_sequence):
        return past_sequence

    # Resolved lazily so that the no-trim paths never need the global settings.
    if attention_sink_size is None:
        attention_sink_size = shared.args.attention_sink_size

    new_tensor = torch.tensor(new_sequence)
    past_tensor = torch.tensor(past_sequence)

    # Tokens before the overlap that both sequences still agree on.
    prefix_length = find_prefix_length(past_tensor[:i1], new_tensor[:j1])
    sink_length = max(prefix_length, attention_sink_size)
    removed_length = i1 - sink_length

    # The sink already reaches (or passes) the overlap, so there is nothing
    # between them to drop. Calling the cache functions with an empty or
    # negative interval would shift the cached positions the wrong way.
    if removed_length <= 0:
        return past_sequence

    matching_prefix = past_tensor[:prefix_length]
    removed_chunk = past_tensor[sink_length:i1]
    added_chunk = new_tensor[j2 + 1:]

    print()
    print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix)))
    print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk)))
    print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk)))
    print()

    # Remove interval [sink_length, sink_length + removed_length) from the
    # context, then move everything after it back by removed_length.
    # NOTE(review): the leading 0 is presumably the sequence id and -1
    # presumably means "until the end" -- confirm against the llama.cpp API.
    model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length)
    model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length)

    # Make the model's bookkeeping describe the new sequence up to the end
    # of the overlap; only the tokens after it are left to evaluate.
    reused_tokens = new_tensor.tolist()[:j2 + 1]
    model.input_ids[:j2 + 1] = reused_tokens
    model.n_tokens = j2 + 1

    return reused_tokens


def find_prefix_length(past_seq, seq_tensor):
    '''
    Given two 1D torch tensors, finds the length of the longest
    common prefix between the two.

    Returns 0 when either tensor is empty, and the length of the shorter
    tensor when it is entirely a prefix of the other.
    '''
    min_length = min(past_seq.shape[0], seq_tensor.shape[0])

    # Positions (within the comparable part) where the two tensors differ.
    mismatches = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
    if len(mismatches) > 0:
        # The first mismatch position is the number of leading equal elements.
        return mismatches[0].item()

    return min_length


def find_longest_common_substring_indices(list1, list2):
    '''
    Given two sequences, solves the Longest Common Substring problem
    (longest run of consecutive equal elements).

    It returns the inclusive indices where the substring starts and ends
    in list1 and in list2, as (start1, end1, start2, end2). When there is
    no common element, it returns (0, -1, 0, -1), i.e. an empty range.
    Among equally long candidates, the one found last wins.

    `list2` must provide `.index(value, start)` like a Python list;
    `list1` only needs to be iterable and indexable.

    Example:

        ir, jr, ir2, jr2 = find_longest_common_substring_indices(list1, list2)
        print(list1[ir:jr + 1])
        print(list2[ir2:jr2 + 1])

    Adapted from
    https://rosettacode.org/wiki/Longest_common_substring#Python
    '''

    len_list1, len_list2 = len(list1), len(list2)
    start_index_list1, end_index_list1 = 0, -1
    start_index_list2, end_index_list2 = 0, -1

    for index1, value in enumerate(list1):
        try:
            index2 = list2.index(value)
        except ValueError:
            # This element never occurs in list2: no match can start here.
            continue

        # Try every occurrence of `value` in list2 as a starting point.
        while True:
            temp_index1, temp_index2 = index1, index2

            # Extend the match while both sequences keep agreeing.
            while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]:
                # `>=` (not `>`) is what lets a later, equally long match win.
                if temp_index1 - index1 >= end_index_list1 - start_index_list1:
                    start_index_list1, end_index_list1 = index1, temp_index1
                    start_index_list2, end_index_list2 = index2, temp_index2

                temp_index1 += 1
                temp_index2 += 1

            try:
                index2 = list2.index(value, index2 + 1)
            except ValueError:
                break

    return start_index_list1, end_index_list1, start_index_list2, end_index_list2
def monkey_patch_generate(lib):
    '''
    Replaces `lib.Llama.generate` with a wrapper that, when
    `shared.args.streaming_llm` is set, runs the StreamingLLM cache
    trimming (`process_llamacpp_cache`) before delegating to the original
    method. The original method is kept as `lib.Llama.original_generate`.

    Calling this more than once for the same `lib` is a no-op.
    '''

    # Patching twice would store the wrapper itself as "original_generate",
    # and the wrapper would then call itself forever.
    if getattr(lib.Llama, 'original_generate', None) is not None:
        return

    def my_generate(self, *args, **kwargs):

        if shared.args.streaming_llm:
            # The token list is the first positional argument, but a caller
            # may also pass it by keyword.
            # NOTE(review): assumes the keyword is named `tokens` -- confirm
            # against the signature of the wrapped generate().
            new_sequence = args[0] if args else kwargs.get('tokens')
            past_sequence = self._input_ids

            # Do the cache trimming for StreamingLLM
            if new_sequence is not None:
                process_llamacpp_cache(self, new_sequence, past_sequence)

        # Delegating with `yield from` also passes send()/throw() through to
        # the wrapped generator instead of swallowing them.
        yield from self.original_generate(*args, **kwargs)

    lib.Llama.original_generate = lib.Llama.generate
    lib.Llama.generate = my_generate
This may improve multi-gpu performance.') +group.add_argument('--streaming-llm', action='store_true', help='Activates StreamingLLM, which prevents the prompt from ever being reevaluated when old chat messages are removed due to the context length for the model being reached.') +group.add_argument('--attention-sink-size', type=int, default=5, help='Minimum attention sink length from StreamingLLM.') # ExLlamaV2 group = parser.add_argument_group('ExLlamaV2') diff --git a/modules/text_generation.py b/modules/text_generation.py index 227d1822..dc9c63ea 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -13,6 +13,7 @@ import transformers from transformers import LogitsProcessorList, is_torch_xpu_available import modules.shared as shared +from modules.cache_utils import process_llamacpp_cache from modules.callbacks import ( Iteratorize, Stream, @@ -364,6 +365,12 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings print(decode(input_ids[0], skip_special_tokens=False)) print() + # Handle StreamingLLM for llamacpp_HF + if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm: + tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids) + shared.model.past_seq = torch.tensor(tmp) + shared.model.save_cache() + t0 = time.time() try: if not is_chat and not shared.is_seq2seq: diff --git a/modules/ui.py b/modules/ui.py index 6e1b12b0..4a03f843 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -97,6 +97,8 @@ def list_model_elements(): 'no_offload_kqv', 'row_split', 'tensorcores', + 'streaming_llm', + 'attention_sink_size', 'hqq_backend', ] if is_torch_xpu_available(): diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index c29db7d0..e3b0e883 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -117,6 +117,8 @@ def create_ui(): shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", 
value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.') shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices) shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards.') + shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') + shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size) shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.') shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.') shared.gradio['no_offload_kqv'] = gr.Checkbox(label="no_offload_kqv", value=shared.args.no_offload_kqv, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')