Cubic sampling w/ curve param (#5551)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Author: kalomaze, 2024-03-03 10:22:21 -06:00 (committed by GitHub)
Parent: 3168644152
Commit: cfb25c9b3f
7 changed files with 26 additions and 12 deletions

View File

@@ -13,6 +13,7 @@ class GenerationOptions(BaseModel):
dynatemp_high: float = 1
dynatemp_exponent: float = 1
smoothing_factor: float = 0
+smoothing_curve: float = 1
top_k: int = 0
repetition_penalty: float = 1
repetition_penalty_range: int = 1024
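
For API users, the new field simply extends the request schema above. Below is a minimal sketch of a completion request that exercises it, assuming the webui's OpenAI-compatible API is enabled on its default local port; the endpoint, port, and values are illustrative, not part of this commit:

import requests

# Hypothetical request against a locally running text-generation-webui API.
payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "smoothing_factor": 0.25,  # > 0 activates quadratic sampling
    "smoothing_curve": 2.0,    # new parameter; 1.0 keeps the original quadratic behavior
}
response = requests.post("http://127.0.0.1:5000/v1/completions", json=payload)
print(response.json()["choices"][0]["text"])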

View File

@@ -164,6 +164,7 @@ def transformers_samplers():
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
+'smoothing_curve',
'top_p',
'min_p',
'top_k',
@@ -240,6 +241,7 @@ loaders_samplers = {
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
+'smoothing_curve',
'top_p',
'min_p',
'top_k',
@@ -298,6 +300,7 @@ loaders_samplers = {
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
+'smoothing_curve',
'top_p',
'min_p',
'top_k',

View File

@@ -19,6 +19,7 @@ def default_preset():
'dynatemp_high': 1,
'dynatemp_exponent': 1,
'smoothing_factor': 0,
+'smoothing_curve': 1,
'top_p': 1,
'min_p': 0,
'top_k': 0,
@@ -109,7 +110,7 @@ def random_preset(state):
[1, 2],
[1, 5]
],
-'smoothing_factor': [0.2, 0.3, 0.6, 1.2]
+'smoothing_factor': [0.2, 0.3, 0.6, 1.2],
},
'repetition': {
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],

View File

@@ -99,22 +99,27 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
class QuadraticSamplingLogitsWarper(LogitsWarper):
'''
-Quadratic sampling.
+Quadratic sampling with smoothing factor and smoothing curve parameters.
'''
-def __init__(self, smoothing_factor: float):
+def __init__(self, smoothing_factor, smoothing_curve):
self.smoothing_factor = smoothing_factor
+self.smoothing_curve = smoothing_curve
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-# Compute the maximum logit value
+# Compute necessary values
max_logit = scores.max()
+diff = scores - max_logit
+k = (3 - self.smoothing_curve) / 2
+s = (self.smoothing_curve - 1) / 2
-# Apply the quadratic transformation
-transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
-# No need to print the top 5 logits since this is not required
-# print("Original top 5 logits: ", torch.topk(scores, 5))
-# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
+# Apply transformation to non-negative infinity values
+transformed_logits = torch.where(
+scores != float('-inf'),
+-(k * self.smoothing_factor * diff**2) + (s * self.smoothing_factor * diff**3) + max_logit,
+scores
+)
return transformed_logits
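
To make the new math concrete: writing d = logit - max_logit, f = smoothing_factor, and c = smoothing_curve, the warper above computes -((3 - c)/2) * f * d**2 + ((c - 1)/2) * f * d**3 + max_logit for every finite logit and leaves -inf entries untouched. At c = 1 the cubic weight vanishes and the formula reduces to the previous purely quadratic transform; larger c shifts weight from the quadratic to the cubic term, which (since d <= 0) pushes low-ranked tokens down more steeply. A small standalone sketch of the same transform follows, with illustrative values; the helper name and numbers are not part of the commit:

import torch

def smoothed_logits(scores, smoothing_factor, smoothing_curve):
    # Mirrors the warper above: penalize each logit by a blend of the squared
    # and cubed distance from the maximum logit, skipping masked (-inf) entries.
    max_logit = scores.max()
    diff = scores - max_logit                 # always <= 0
    k = (3 - smoothing_curve) / 2             # quadratic weight (1 when curve == 1)
    s = (smoothing_curve - 1) / 2             # cubic weight (0 when curve == 1)
    return torch.where(
        scores != float('-inf'),
        -(k * smoothing_factor * diff**2) + (s * smoothing_factor * diff**3) + max_logit,
        scores,
    )

logits = torch.tensor([5.0, 4.0, 2.0, -1.0, float('-inf')])
print(smoothed_logits(logits, 0.3, 1.0))  # curve = 1: identical to the old quadratic transform
print(smoothed_logits(logits, 0.3, 3.0))  # curve = 3: quadratic term gone, pure cubic falloff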
@@ -367,7 +372,8 @@ def get_logits_warper_patch(self, generation_config):
if generation_config.smoothing_factor > 0:
warpers_to_add.append(
QuadraticSamplingLogitsWarper(
-smoothing_factor=generation_config.smoothing_factor
+smoothing_factor=generation_config.smoothing_factor,
+smoothing_curve=generation_config.smoothing_curve
)
)
@@ -468,6 +474,7 @@ def generation_config_init_patch(self, **kwargs):
self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
+self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)

View File

@@ -286,7 +286,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
-for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
+for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
if k in state:
generate_params[k] = state[k]

View File

@@ -123,6 +123,7 @@ def list_interface_input_elements():
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
+'smoothing_curve',
'top_p',
'min_p',
'top_k',

View File

@@ -50,6 +50,7 @@ def create_ui(default_preset):
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Activates Quadratic Sampling.')
+shared.gradio['smoothing_curve'] = gr.Slider(1.0, 10.0, value=generate_params['smoothing_curve'], step=0.01, label='smoothing_curve', info='Adjusts the dropoff curve of Quadratic Sampling.')
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])