diff --git a/modules/text_generation.py b/modules/text_generation.py
index 9a908df3..295b0dc5 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -1,6 +1,7 @@
 import ast
 import copy
 import html
+import pprint
 import random
 import re
 import time
@@ -65,7 +66,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
             all_stop_strings += st

     if shared.args.verbose:
-        print(f'\n\n{question}\n--------------------\n')
+        logger.info("PROMPT=")
+        print(question)

     shared.stop_everything = False
     clear_torch_cache()
@@ -283,7 +285,7 @@ def get_reply_from_output_ids(output_ids, state, starting_from=0):

 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
+    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
         generate_params[k] = state[k]

     if state['negative_prompt'] != '':
@@ -342,6 +344,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         apply_extensions('logits_processor', processor, input_ids)
         generate_params['logits_processor'] = processor

+    if shared.args.verbose:
+        logger.info("GENERATE_PARAMS=")
+        pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(generate_params)
+        print()
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq: