Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-10-01 01:26:03 -04:00

Quadratic sampling (#5403)

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>

Parent: e98d1086f5
Commit: b6077b02e4

@@ -55,6 +55,7 @@ For more information about the parameters, the [transformers documentation](http
 * **mirostat_tau**: No idea, see the paper for details. According to the Preset Arena, 8 is a good value.
 * **mirostat_eta**: No idea, see the paper for details. According to the Preset Arena, 0.1 is a good value.
 * **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent".
+* **smoothing_factor**: Activates Quadratic Sampling. This takes precedence over regular temperature and dynamic temperature, and replaces those samplers. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked.
 * **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency.
 * **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
 * **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (notably ExLlama v1 and v2). For these loaders, the seed has no effect.

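The smoothing_factor behavior described above can be made concrete with a small sketch. This is not part of the commit; it simply applies the same quadratic transform the new warper uses to a made-up logits tensor, so the "flatter vs. more peaked" claim can be checked directly.

    import torch

    def quadratic_sample(scores: torch.Tensor, smoothing_factor: float) -> torch.Tensor:
        # Center the parabola on the largest logit so the top token keeps its score,
        # then pull the other logits down by the squared distance from that maximum
        max_logit = scores.max()
        return -(smoothing_factor * (scores - max_logit) ** 2) + max_logit

    logits = torch.tensor([4.0, 3.0, 2.0, 1.0])
    print(torch.softmax(logits, dim=-1))                         # original distribution
    print(torch.softmax(quadratic_sample(logits, 0.3), dim=-1))  # flatter (0 < factor < 1)
    print(torch.softmax(quadratic_sample(logits, 3.0), dim=-1))  # more peaked (factor > 1)
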
@@ -12,6 +12,7 @@ class GenerationOptions(BaseModel):
     dynatemp_low: float = 1
     dynatemp_high: float = 1
     dynatemp_exponent: float = 1
+    smoothing_factor: float = 0
     top_k: int = 0
     repetition_penalty: float = 1
     repetition_penalty_range: int = 1024

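Since the new field is part of the API schema above, a client can pass it directly in a request body. The snippet below is a hedged sketch: the local URL, port, prompt, and the other fields are assumptions, and only `smoothing_factor` with its default of 0 (quadratic sampling disabled) comes from the hunk.

    import requests

    # Hypothetical request against a locally running text-generation-webui API
    response = requests.post(
        "http://127.0.0.1:5000/v1/completions",
        json={
            "prompt": "Once upon a time",
            "max_tokens": 64,
            "smoothing_factor": 0.3,  # 0 disables quadratic sampling (the schema default above)
        },
    )
    print(response.json())
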
@@ -159,6 +159,7 @@ def transformers_samplers():
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',

@@ -233,6 +234,7 @@ loaders_samplers = {
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',

@@ -289,6 +291,7 @@ loaders_samplers = {
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',

@@ -17,6 +17,7 @@ def default_preset():
         'dynatemp_low': 1,
         'dynatemp_high': 1,
         'dynatemp_exponent': 1,
+        'smoothing_factor': 0,
         'top_p': 1,
         'min_p': 0,
         'top_k': 0,

@@ -15,8 +15,12 @@ from modules import shared
 global_scores = None


-class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
-    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
+class ModifiedTemperatureLogitsWarper(LogitsWarper):
+    '''
+    Based on the original Transformers temperature logits warper, this
+    adds support for dynamic temperature and quadratic sampling.
+    '''
+    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
         if not isinstance(temperature, float) or not (temperature > 0):
             except_msg = (
                 f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "

@@ -32,16 +36,27 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
         self.dynatemp_low = dynatemp_low
         self.dynatemp_high = dynatemp_high
         self.dynatemp_exponent = dynatemp_exponent
+        self.smoothing_factor = smoothing_factor

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

-        # Regular temperature
-        if not self.dynamic_temperature:
-            scores = scores / self.temperature
-            return scores
+        # Quadratic sampling
+        if self.smoothing_factor > 0:
+
+            # Compute the maximum logit value
+            max_logit = scores.max()
+
+            # Apply the quadratic transformation
+            transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
+
+            # No need to print the top 5 logits since this is not required
+            # print("Original top 5 logits: ", torch.topk(scores, 5))
+            # print("New top 5 logits: ", torch.topk(transformed_logits, 5))
+
+            return transformed_logits

         # Dynamic temperature
-        else:
+        elif self.dynamic_temperature:
             min_temp = self.dynatemp_low
             max_temp = self.dynatemp_high
             exponent_val = self.dynatemp_exponent

@@ -88,6 +103,11 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):

             return scores
+
+        # Regular temperature
+        else:
+            scores = scores / self.temperature
+            return scores


 class MinPLogitsWarper(LogitsWarper):
     def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):

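Taken together, the two hunks above give `__call__` a simple routing: quadratic sampling wins whenever `smoothing_factor > 0`, dynamic temperature comes next, and plain temperature scaling is the fallback. The following standalone sketch is a simplified restatement (not the commit's code) of that precedence; the dynamic-temperature body is omitted here just as it is elided in the hunks.

    import torch

    def route_temperature(scores: torch.Tensor, temperature: float,
                          dynamic_temperature: bool, smoothing_factor: float) -> torch.Tensor:
        if smoothing_factor > 0:
            # Quadratic sampling, same transform as the warper above
            max_logit = scores.max()
            return -(smoothing_factor * (scores - max_logit) ** 2) + max_logit
        elif dynamic_temperature:
            # Entropy-based scaling between dynatemp_low and dynatemp_high (omitted)
            raise NotImplementedError
        else:
            # Regular temperature
            return scores / temperature

    logits = torch.tensor([[5.0, 2.0, 0.5]])
    # smoothing_factor takes precedence even though dynamic_temperature is requested
    print(route_temperature(logits, temperature=0.7, dynamic_temperature=True, smoothing_factor=0.5))
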
@@ -286,7 +306,7 @@ def get_logits_warper_patch(self, generation_config):
     generation_config.temperature = float(generation_config.temperature)

     temperature = generation_config.temperature
-    if generation_config.dynamic_temperature:
+    if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
         # Make sure TemperatureLogitsWarper will be created by temporarily
         # setting temperature to a value != 1.
         generation_config.temperature = 1.1

@@ -294,12 +314,13 @@ def get_logits_warper_patch(self, generation_config):
     warpers = self._get_logits_warper_old(generation_config)
     for i in range(len(warpers)):
         if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
-            warpers[i] = TemperatureLogitsWarperWithDynatemp(
+            warpers[i] = ModifiedTemperatureLogitsWarper(
                 temperature,
                 generation_config.dynamic_temperature,
                 generation_config.dynatemp_low,
                 generation_config.dynatemp_high,
-                generation_config.dynatemp_exponent
+                generation_config.dynatemp_exponent,
+                generation_config.smoothing_factor
             )

     warpers_to_add = LogitsProcessorList()

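The swap above relies on the trick noted in the earlier comment: the stock TemperatureLogitsWarper is only created when the temperature is not 1, so the patch forces a placeholder value (1.1), lets the default warper list get built, and then replaces the stock warper by class name. A hedged, self-contained sketch of that pattern, using a trivial stand-in subclass instead of the real ModifiedTemperatureLogitsWarper so it runs on its own:

    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    class StandInWarper(TemperatureLogitsWarper):
        # Placeholder for ModifiedTemperatureLogitsWarper, used only to keep this sketch runnable
        pass

    # As if returned by self._get_logits_warper_old(...) after forcing temperature to 1.1
    warpers = LogitsProcessorList([TemperatureLogitsWarper(1.1)])
    for i in range(len(warpers)):
        if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
            warpers[i] = StandInWarper(0.7)  # the real code passes the saved temperature plus the new options

    print([w.__class__.__name__ for w in warpers])  # ['StandInWarper']
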
@@ -328,7 +349,7 @@ def get_logits_warper_patch(self, generation_config):
     if generation_config.temperature_last:
         temperature_idx = None
         for i in range(len(warpers)):
-            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
+            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
                 temperature_idx = i
                 break

@@ -352,8 +373,7 @@ def get_logits_processor_patch(self, **kwargs):
     repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
     do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
     if do_rep_pen_hijack:
-        # Make sure that a RepetitionPenaltyLogitsProcessor will be created
-        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1
+        kwargs['generation_config'].repetition_penalty = 1.1  # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created

     result = self._get_logits_processor_old(**kwargs)

@@ -372,6 +392,7 @@ def generation_config_init_patch(self, **kwargs):
     self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
     self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
     self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
+    self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
     self.tfs = kwargs.pop("tfs", 1.0)
     self.top_a = kwargs.pop("top_a", 0.0)
     self.mirostat_mode = kwargs.pop("mirostat_mode", 0)

@@ -285,8 +285,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):

 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
-        generate_params[k] = state[k]
+    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
+        if k in state:
+            generate_params[k] = state[k]

     if state['negative_prompt'] != '':
         generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])

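The new `if k in state` guard changes the failure mode of this loop: keys missing from `state` (for example, a state saved before `smoothing_factor` existed, which is an assumed scenario rather than something stated in the commit) are now skipped instead of raising a KeyError. A minimal illustration, not the commit's code:

    state = {'max_new_tokens': 200, 'temperature': 0.7}  # no 'smoothing_factor' key
    generate_params = {}
    for k in ['max_new_tokens', 'temperature', 'smoothing_factor']:
        if k in state:
            generate_params[k] = state[k]  # the old code did this unconditionally
    print(generate_params)  # {'max_new_tokens': 200, 'temperature': 0.7}
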
@@ -120,6 +120,7 @@ def list_interface_input_elements():
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',

@@ -49,6 +49,7 @@ def create_ui(default_preset):
         shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
         shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
         shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
+        shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Replaces temperature with Quadratic Sampling.')
         shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
         shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
        shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])