diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 53177f4f..e5401378 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -6,6 +6,7 @@ import torch
 from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
+from modules.text_generation import get_max_prompt_length
 
 import llama_cpp
 
@@ -91,6 +92,12 @@ class LlamaCppModel:
         LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
 
         prompt = prompt if type(prompt) is str else prompt.decode()
+
+        # Handle truncation
+        prompt = self.encode(prompt)
+        prompt = prompt[-get_max_prompt_length(state):]
+        prompt = self.decode(prompt).decode('utf-8')
+
         completion_chunks = self.model.create_completion(
             prompt=prompt,
             max_tokens=state['max_new_tokens'],
diff --git a/modules/text_generation.py b/modules/text_generation.py
index f6f71990..7507a731 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -39,7 +39,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
-        return input_ids
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
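For reference, the truncation step added to LlamaCppModel.generate trims the prompt to the most recent tokens that still leave room for generation. Below is a minimal, self-contained sketch of that flow; the whitespace tokenizer and max_prompt_length() helper are hypothetical stand-ins for the model's real encode/decode pair and for get_max_prompt_length(state).

# Minimal sketch of the truncation flow, under the assumption that a toy
# whitespace tokenizer can stand in for the llama.cpp encode/decode pair.
class ToyTokenizer:
    def encode(self, text):
        # Real code returns token ids; whitespace "tokens" are enough here.
        return text.split()

    def decode(self, tokens):
        return ' '.join(tokens)


def max_prompt_length(truncation_length, max_new_tokens):
    # Hypothetical stand-in for get_max_prompt_length(state): reserve room
    # for the tokens that will be generated.
    return truncation_length - max_new_tokens


def truncate_prompt(prompt, tokenizer, truncation_length, max_new_tokens):
    tokens = tokenizer.encode(prompt)
    # Keep only the most recent tokens that fit within the prompt budget,
    # mirroring prompt = prompt[-get_max_prompt_length(state):] above.
    tokens = tokens[-max_prompt_length(truncation_length, max_new_tokens):]
    return tokenizer.decode(tokens)


if __name__ == '__main__':
    long_prompt = ' '.join(f'tok{i}' for i in range(100))
    # Keeps the last 24 "tokens" when 32 fit overall and 8 are reserved for output.
    print(truncate_prompt(long_prompt, ToyTokenizer(), truncation_length=32, max_new_tokens=8))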