From cfb25c9b3f0dc1521ff00036364b1235ef50e8e6 Mon Sep 17 00:00:00 2001
From: kalomaze <66376113+kalomaze@users.noreply.github.com>
Date: Sun, 3 Mar 2024 10:22:21 -0600
Subject: [PATCH] Cubic sampling w/ curve param (#5551)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
---
 extensions/openai/typing.py |  1 +
 modules/loaders.py          |  3 +++
 modules/presets.py          |  3 ++-
 modules/sampler_hijack.py   | 27 +++++++++++++++++----------
 modules/text_generation.py  |  2 +-
 modules/ui.py               |  1 +
 modules/ui_parameters.py    |  1 +
 7 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index ec351167..3ae02e68 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -13,6 +13,7 @@ class GenerationOptions(BaseModel):
     dynatemp_high: float = 1
     dynatemp_exponent: float = 1
     smoothing_factor: float = 0
+    smoothing_curve: float = 1
     top_k: int = 0
     repetition_penalty: float = 1
     repetition_penalty_range: int = 1024
diff --git a/modules/loaders.py b/modules/loaders.py
index 08a7f229..513fd910 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -164,6 +164,7 @@ def transformers_samplers():
         'dynatemp_high',
         'dynatemp_exponent',
         'smoothing_factor',
+        'smoothing_curve',
         'top_p',
         'min_p',
         'top_k',
@@ -240,6 +241,7 @@ loaders_samplers = {
         'dynatemp_high',
         'dynatemp_exponent',
         'smoothing_factor',
+        'smoothing_curve',
         'top_p',
         'min_p',
         'top_k',
@@ -298,6 +300,7 @@ loaders_samplers = {
         'dynatemp_high',
         'dynatemp_exponent',
         'smoothing_factor',
+        'smoothing_curve',
         'top_p',
         'min_p',
         'top_k',
diff --git a/modules/presets.py b/modules/presets.py
index d4fcc7d0..7a041311 100644
--- a/modules/presets.py
+++ b/modules/presets.py
@@ -19,6 +19,7 @@ def default_preset():
         'dynatemp_high': 1,
         'dynatemp_exponent': 1,
         'smoothing_factor': 0,
+        'smoothing_curve': 1,
         'top_p': 1,
         'min_p': 0,
         'top_k': 0,
@@ -109,7 +110,7 @@ def random_preset(state):
                 [1, 2],
                 [1, 5]
             ],
-            'smoothing_factor': [0.2, 0.3, 0.6, 1.2]
+            'smoothing_factor': [0.2, 0.3, 0.6, 1.2],
         },
         'repetition': {
             'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index 057052e9..da52f4d0 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -99,22 +99,27 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
 
 class QuadraticSamplingLogitsWarper(LogitsWarper):
     '''
-    Quadratic sampling.
+    Quadratic sampling with smoothing factor and smoothing curve parameters.
     '''
 
-    def __init__(self, smoothing_factor: float):
+    def __init__(self, smoothing_factor, smoothing_curve):
         self.smoothing_factor = smoothing_factor
+        self.smoothing_curve = smoothing_curve
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        # Compute the maximum logit value
+
+        # Compute necessary values
         max_logit = scores.max()
+        diff = scores - max_logit
+        k = (3 - self.smoothing_curve) / 2
+        s = (self.smoothing_curve - 1) / 2
 
-        # Apply the quadratic transformation
-        transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
-
-        # No need to print the top 5 logits since this is not required
-        # print("Original top 5 logits: ", torch.topk(scores, 5))
-        # print("New top 5 logits: ", torch.topk(transformed_logits, 5))
+        # Apply the transformation only where the logits are not -inf
+        transformed_logits = torch.where(
+            scores != float('-inf'),
+            -(k * self.smoothing_factor * diff**2) + (s * self.smoothing_factor * diff**3) + max_logit,
+            scores
+        )
 
         return transformed_logits
 
@@ -367,7 +372,8 @@ def get_logits_warper_patch(self, generation_config):
     if generation_config.smoothing_factor > 0:
         warpers_to_add.append(
             QuadraticSamplingLogitsWarper(
-                smoothing_factor=generation_config.smoothing_factor
+                smoothing_factor=generation_config.smoothing_factor,
+                smoothing_curve=generation_config.smoothing_curve
             )
         )
 
@@ -468,6 +474,7 @@ def generation_config_init_patch(self, **kwargs):
     self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
     self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
     self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
+    self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0)
     self.tfs = kwargs.pop("tfs", 1.0)
     self.top_a = kwargs.pop("top_a", 0.0)
     self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index c62b9b01..227d1822 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -286,7 +286,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
 
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
+    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
         if k in state:
             generate_params[k] = state[k]
 
diff --git a/modules/ui.py b/modules/ui.py
index bb5a3339..6249bb48 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -123,6 +123,7 @@ def list_interface_input_elements():
         'dynatemp_high',
         'dynatemp_exponent',
         'smoothing_factor',
+        'smoothing_curve',
         'top_p',
         'min_p',
         'top_k',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index 078590dc..7aebe672 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -50,6 +50,7 @@ def create_ui(default_preset):
                     shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
                     shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
                     shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Activates Quadratic Sampling.')
+                    shared.gradio['smoothing_curve'] = gr.Slider(1.0, 10.0, value=generate_params['smoothing_curve'], step=0.01, label='smoothing_curve', info='Adjusts the dropoff curve of Quadratic Sampling.')
                    shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
                    shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
                    shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])
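
A note on the math of the new warper: with diff = scores - max_logit, it computes -(k * smoothing_factor * diff^2) + (s * smoothing_factor * diff^3) + max_logit, where k = (3 - smoothing_curve) / 2 and s = (smoothing_curve - 1) / 2. At smoothing_curve == 1 this reduces to the previous pure-quadratic behavior (k == 1, s == 0); at smoothing_curve == 3 the quadratic term vanishes (k == 0) and only the cubic term remains, which is the "cubic sampling" of the patch title. A minimal standalone sketch of the same transformation, outside the LogitsWarper machinery (the helper name smooth_logits and the example tensor are illustrative, not part of the patch):

```python
import torch

def smooth_logits(scores: torch.Tensor, smoothing_factor: float, smoothing_curve: float) -> torch.Tensor:
    # Same math as QuadraticSamplingLogitsWarper.__call__ in the patch.
    max_logit = scores.max()
    diff = scores - max_logit      # <= 0 everywhere
    k = (3 - smoothing_curve) / 2  # quadratic weight: 1 at curve=1, 0 at curve=3
    s = (smoothing_curve - 1) / 2  # cubic weight: 0 at curve=1, 1 at curve=3
    return torch.where(
        scores != float('-inf'),   # leave masked (-inf) logits untouched
        -(k * smoothing_factor * diff**2) + (s * smoothing_factor * diff**3) + max_logit,
        scores,
    )

logits = torch.tensor([5.0, 3.0, 1.0, float('-inf')])
print(smooth_logits(logits, smoothing_factor=0.5, smoothing_curve=1.0))  # pure quadratic
print(smooth_logits(logits, smoothing_factor=0.5, smoothing_curve=3.0))  # pure cubic
```

Because the transformation is anchored at max_logit (diff == 0 there), the top token's logit is unchanged and everything below it is pulled down: smoothing_factor controls how aggressively, while smoothing_curve shifts weight from the quadratic to the cubic term, flattening the dropoff among the top candidates while suppressing the tail harder.

On the API side, the new smoothing_curve field of GenerationOptions travels in the request body of the OpenAI-compatible extension like any other sampling parameter. A sketch of such a request, assuming a default local setup (host, port, and prompt here are assumptions, not from the patch):

```python
import requests

# Assumes the OpenAI-compatible API is enabled on its default local port.
payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "smoothing_factor": 0.5,  # > 0 activates quadratic sampling
    "smoothing_curve": 2.0,   # 1.0 = pure quadratic; larger values add the cubic term
}
r = requests.post("http://127.0.0.1:5000/v1/completions", json=payload)
print(r.json()["choices"][0]["text"])
```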