diff --git a/modules/text_generation.py b/modules/text_generation.py
index 17ec6bd8..79067f84 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -56,8 +56,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
 
     if generate_func != generate_reply_HF and shared.args.verbose:
         logger.info("PROMPT=")
-        print(question)
-        print()
+        print_prompt(question)
 
     # Prepare the input
     original_question = question
@@ -343,8 +342,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         print()
 
         logger.info("PROMPT=")
-        print(decode(input_ids[0], skip_special_tokens=False))
-        print()
+        print_prompt(decode(input_ids[0], skip_special_tokens=False))
 
     # Handle StreamingLLM for llamacpp_HF
     if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
@@ -433,3 +431,18 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
         new_tokens = len(encode(original_question + reply)[0]) - original_tokens
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
+
+
+def print_prompt(prompt, max_chars=2000):
+    DARK_YELLOW = "\033[38;5;3m"
+    RESET = "\033[0m"
+
+    if len(prompt) > max_chars:
+        half_chars = max_chars // 2
+        hidden_len = len(prompt[half_chars:-half_chars])
+        hidden_msg = f"{DARK_YELLOW}[...{hidden_len} characters hidden...]{RESET}"
+        print(prompt[:half_chars] + hidden_msg + prompt[-half_chars:])
+    else:
+        print(prompt)
+
+    print()
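
For reference, a minimal standalone sketch of the truncation behavior the new helper implements: the function body below mirrors the print_prompt() added in this diff, while the sample inputs and the reduced max_chars value are made up for illustration only.

```python
# Standalone sketch of print_prompt(): keep the head and tail of a long prompt
# and replace the middle with a colored "[...N characters hidden...]" marker.
# The function body mirrors the diff above; the demo inputs are illustrative.

def print_prompt(prompt, max_chars=2000):
    DARK_YELLOW = "\033[38;5;3m"  # ANSI escape for dark yellow text
    RESET = "\033[0m"             # ANSI escape to reset terminal colors

    if len(prompt) > max_chars:
        # Keep the first and last max_chars // 2 characters and report how
        # many characters in the middle were hidden.
        half_chars = max_chars // 2
        hidden_len = len(prompt[half_chars:-half_chars])
        hidden_msg = f"{DARK_YELLOW}[...{hidden_len} characters hidden...]{RESET}"
        print(prompt[:half_chars] + hidden_msg + prompt[-half_chars:])
    else:
        print(prompt)

    print()


if __name__ == "__main__":
    # A 100-character prompt with max_chars=20: the first and last 10 characters
    # are printed, with "[...80 characters hidden...]" between them.
    print_prompt("abcdefghij" * 10, max_chars=20)

    # Prompts at or below max_chars are printed unchanged.
    print_prompt("short prompt", max_chars=20)
```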