From aa5d671579eebc502ada5e1c251240801ffe0fa8 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 4 Nov 2023 13:09:07 -0300 Subject: [PATCH] Add temperature_last parameter (#4472) --- extensions/api/util.py | 1 + modules/loaders.py | 7 +++++++ modules/presets.py | 1 + modules/sampler_hijack.py | 32 ++++++++++++++++++++++++++------ modules/text_generation.py | 2 +- modules/ui.py | 1 + modules/ui_parameters.py | 1 + 7 files changed, 38 insertions(+), 7 deletions(-) diff --git a/extensions/api/util.py b/extensions/api/util.py index 4fb2720a..206f2597 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -25,6 +25,7 @@ def build_parameters(body, chat=False): 'max_tokens_second': int(body.get('max_tokens_second', 0)), 'do_sample': bool(body.get('do_sample', True)), 'temperature': float(body.get('temperature', 0.5)), + 'temperature_last': bool(body.get('temperature_last', False)), 'top_p': float(body.get('top_p', 1)), 'min_p': float(body.get('min_p', 0)), 'typical_p': float(body.get('typical_p', body.get('typical', 1))), diff --git a/modules/loaders.py b/modules/loaders.py index 6b09f6af..c9accc34 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -148,6 +148,7 @@ loaders_and_params = OrderedDict({ loaders_samplers = { 'Transformers': { 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', @@ -184,6 +185,7 @@ loaders_samplers = { }, 'ExLlama_HF': { 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', @@ -245,6 +247,7 @@ loaders_samplers = { }, 'ExLlamav2_HF': { 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', @@ -277,6 +280,7 @@ loaders_samplers = { }, 'AutoGPTQ': { 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', @@ -313,6 +317,7 @@ loaders_samplers = { }, 'GPTQ-for-LLaMa': { 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', @@ -365,6 +370,7 @@ loaders_samplers = { }, 'llamacpp_HF': { 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', @@ -404,6 +410,7 @@ loaders_samplers = { }, 'AutoAWQ': { 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', diff --git a/modules/presets.py b/modules/presets.py index 27bd2cba..62c2c90c 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -8,6 +8,7 @@ def default_preset(): return { 'do_sample': True, 'temperature': 1, + 'temperature_last': False, 'top_p': 1, 'min_p': 0, 'top_k': 0, diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index a4fee744..218d1b11 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -12,6 +12,7 @@ from transformers.generation.logits_process import ( global_scores = None + class MinPLogitsWarper(LogitsWarper): def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): if min_p < 0 or min_p > 1.0: @@ -41,6 +42,7 @@ class MinPLogitsWarper(LogitsWarper): scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores + class TailFreeLogitsWarper(LogitsWarper): def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): tfs = float(tfs) @@ -214,19 +216,36 @@ def get_logits_warper_patch(self, generation_config): if not isinstance(warper, TemperatureLogitsWarper): warpers.remove(warper) else: - if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0: + if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0: warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0: + if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0: warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.min_p is not None and 0.0 <= generation_config.min_p <= 1.0: + if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0: warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) - if warpers and isinstance(warpers[-1], LogitNormalization): - warpers = warpers[:-1] + warpers_to_add + [warpers[-1]] + if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization): + normalize = warpers.pop(-1) else: - warpers += warpers_to_add + normalize = None + + warpers += warpers_to_add + if generation_config.temperature_last: + temperature_idx = None + for i in range(len(warpers)): + if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper': + temperature_idx = i + break + + if temperature_idx is not None: + warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]] + warpers = LogitsProcessorList(warpers) + + if normalize is not None: + warpers.append(normalize) warpers.append(SpyLogitsWarper()) + # for i in range(len(warpers)): + # print(warpers[i].__class__.__name__) return warpers @@ -261,6 +280,7 @@ def generation_config_init_patch(self, **kwargs): self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0) self.presence_penalty = kwargs.pop("presence_penalty", 0) self.frequency_penalty = kwargs.pop("frequency_penalty", 0) + self.temperature_last = kwargs.pop("temperature_last", False) def hijack_samplers(): diff --git a/modules/text_generation.py b/modules/text_generation.py index 61d3d0f4..e2efa41d 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -274,7 +274,7 @@ def apply_stopping_strings(reply, all_stop_strings): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']: + for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']: generate_params[k] = state[k] if state['negative_prompt'] != '': diff --git a/modules/ui.py b/modules/ui.py index bc689fe0..466af187 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -104,6 +104,7 @@ def list_interface_input_elements(): 'max_tokens_second', 'seed', 'temperature', + 'temperature_last', 'top_p', 'min_p', 'top_k', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 97ea18ed..fa245c4d 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -48,6 +48,7 @@ def create_ui(default_preset): shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.') shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') + shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.') shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') with gr.Accordion('Other parameters', open=False):