mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Dynamic Temperature HF loader support (#5174)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
3eca20c015
commit
48327cc5c4
@ -54,6 +54,7 @@ For more information about the parameters, the [transformers documentation](http
|
||||
* **mirostat_mode**: Activates the Mirostat sampling technique. It aims to control perplexity during sampling. See the [paper](https://arxiv.org/abs/2007.14966).
|
||||
* **mirostat_tau**: No idea, see the paper for details. According to the Preset Arena, 8 is a good value.
|
||||
* **mirostat_eta**: No idea, see the paper for details. According to the Preset Arena, 0.1 is a good value.
|
||||
* **dynatemp**: Dynamic Temperature is activated when this parameter is greater than 0. The temperature range is determined by adding and subtracting dynatemp from the current temperature.
|
||||
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency.
|
||||
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
|
||||
* **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (notably ExLlama v1 and v2). For these loaders, the seed has no effect.
|
||||
|
17
extensions/dynatemp_with_range/README.md
Normal file
17
extensions/dynatemp_with_range/README.md
Normal file
@ -0,0 +1,17 @@
|
||||
# dynatemp_with_range
|
||||
|
||||
This extension makes it possible to set the minimum and maximum temperatures for dynamic temperature explicitly.
|
||||
|
||||
For instance, you can directly set
|
||||
|
||||
```
|
||||
min_T = 0.1
|
||||
max_T = 3
|
||||
```
|
||||
|
||||
instead of having to convert that to
|
||||
|
||||
```
|
||||
T = 1.55
|
||||
dynatemp = 1.45
|
||||
```
|
50
extensions/dynatemp_with_range/script.py
Normal file
50
extensions/dynatemp_with_range/script.py
Normal file
@ -0,0 +1,50 @@
|
||||
import gradio as gr
|
||||
|
||||
params = {
|
||||
"activate": True,
|
||||
"minimum_temperature": 0.1,
|
||||
"maximum_temperature": 2,
|
||||
}
|
||||
|
||||
def convert_to_dynatemp():
|
||||
temperature = 0.5 * (params["minimum_temperature"] + params["maximum_temperature"])
|
||||
dynatemp = params["maximum_temperature"] - temperature
|
||||
return temperature, dynatemp
|
||||
|
||||
|
||||
def state_modifier(state):
|
||||
"""
|
||||
Modifies the state variable, which is a dictionary containing the input
|
||||
values in the UI like sliders and checkboxes.
|
||||
"""
|
||||
|
||||
if params["activate"]:
|
||||
temperature, dynatemp = convert_to_dynatemp()
|
||||
|
||||
state["temperature"] = temperature
|
||||
state["dynatemp"] = dynatemp
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def generate_info():
|
||||
temperature, dynatemp = convert_to_dynatemp()
|
||||
return f"The combination above is equivalent to: T={temperature:.2f}, dynatemp={dynatemp:.2f}"
|
||||
|
||||
|
||||
def ui():
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate Dynamic Temperature Range', info='When checked, the default temperature/dynatemp parameters are ignored and the parameters below are used instead.')
|
||||
with gr.Row():
|
||||
minimum_temperature = gr.Slider(0, 5, step=0.01, label="Minimum temperature", value=params["minimum_temperature"], interactive=True)
|
||||
maximum_temperature = gr.Slider(0, 5, step=0.01, label="Maximum temperature", value=params["maximum_temperature"], interactive=True)
|
||||
|
||||
info = gr.HTML(generate_info())
|
||||
|
||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||
minimum_temperature.change(
|
||||
lambda x: params.update({"minimum_temperature": x}), minimum_temperature, None).then(
|
||||
generate_info, None, info, show_progress=False)
|
||||
|
||||
maximum_temperature.change(
|
||||
lambda x: params.update({"maximum_temperature": x}), maximum_temperature, None).then(
|
||||
generate_info, None, info, show_progress=False)
|
@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
|
||||
class GenerationOptions(BaseModel):
|
||||
preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
|
||||
min_p: float = 0
|
||||
dynatemp: float = 0
|
||||
top_k: int = 0
|
||||
repetition_penalty: float = 1
|
||||
repetition_penalty_range: int = 1024
|
||||
|
@ -10,6 +10,10 @@ from transformers import is_torch_xpu_available
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
class StopNowException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self):
|
||||
transformers.StoppingCriteria.__init__(self)
|
||||
@ -49,13 +53,13 @@ class Iteratorize:
|
||||
|
||||
def _callback(val):
|
||||
if self.stop_now or shared.stop_everything:
|
||||
raise ValueError
|
||||
raise StopNowException
|
||||
self.q.put(val)
|
||||
|
||||
def gentask():
|
||||
try:
|
||||
ret = self.mfunc(callback=_callback, *args, **self.kwargs)
|
||||
except ValueError:
|
||||
except StopNowException:
|
||||
pass
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
@ -144,6 +144,9 @@ class LlamacppHF(PreTrainedModel):
|
||||
self.model.n_tokens = longest_prefix
|
||||
if len(seq_tensor) - longest_prefix > 0:
|
||||
self.model.eval(seq[longest_prefix:])
|
||||
else:
|
||||
self.model.n_tokens -= 1
|
||||
self.model.eval([seq[-1]])
|
||||
|
||||
if reset:
|
||||
self.model.reset()
|
||||
|
@ -155,6 +155,7 @@ def transformers_samplers():
|
||||
return {
|
||||
'temperature',
|
||||
'temperature_last',
|
||||
'dynatemp',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
@ -220,6 +221,7 @@ loaders_samplers = {
|
||||
'ExLlamav2_HF': {
|
||||
'temperature',
|
||||
'temperature_last',
|
||||
'dynatemp',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
@ -272,6 +274,7 @@ loaders_samplers = {
|
||||
'llamacpp_HF': {
|
||||
'temperature',
|
||||
'temperature_last',
|
||||
'dynatemp',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
@ -8,7 +8,7 @@ from modules.text_generation import generate_reply
|
||||
global_scores = None
|
||||
|
||||
|
||||
def get_next_logits(prompt, state, use_samplers, previous, top_logits=50, return_dict=False):
|
||||
def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
|
||||
if shared.model is None:
|
||||
logger.error("No model is loaded! Select one in the Model tab.")
|
||||
return 'Error: No model is loaded1 Select one in the Model tab.', previous
|
||||
|
@ -12,6 +12,7 @@ def default_preset():
|
||||
return {
|
||||
'temperature': 1,
|
||||
'temperature_last': False,
|
||||
'dynatemp': 0,
|
||||
'top_p': 1,
|
||||
'min_p': 0,
|
||||
'top_k': 0,
|
||||
|
@ -10,9 +10,84 @@ from transformers.generation.logits_process import (
|
||||
TemperatureLogitsWarper
|
||||
)
|
||||
|
||||
from modules import shared
|
||||
|
||||
global_scores = None
|
||||
|
||||
|
||||
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
||||
def __init__(self, temperature: float, dynatemp: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
except_msg = (
|
||||
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
|
||||
"scores will be invalid."
|
||||
)
|
||||
if isinstance(temperature, float) and temperature == 0.0:
|
||||
except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
|
||||
|
||||
raise ValueError(except_msg)
|
||||
|
||||
self.temperature = temperature
|
||||
self.dynatemp = dynatemp
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
# Regular temperature
|
||||
if self.dynatemp == 0:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
# Dynamic temperature
|
||||
else:
|
||||
min_temp = max(0.0, self.temperature - self.dynatemp)
|
||||
max_temp = self.temperature + self.dynatemp
|
||||
exponent_val = 1.0
|
||||
|
||||
# Convert logits to probabilities
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
|
||||
# Calculate entropy of the softmax probabilities
|
||||
entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()
|
||||
|
||||
# Guard against future possible division by zero
|
||||
entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0
|
||||
|
||||
# Any logits which are not -Infinity will be considered for calculating max entropy.
|
||||
num_valid_tokens = torch.sum(scores > -float('inf')).item()
|
||||
|
||||
# Now, calculate the max entropy by using only the valid tokens' count
|
||||
max_entropy = math.log(num_valid_tokens)
|
||||
|
||||
# Guard against future possible division by zero
|
||||
max_entropy = max_entropy if max_entropy > 0.0 else 1e-10
|
||||
|
||||
# Normalize the entropy
|
||||
normalized_entropy = entropy / max_entropy
|
||||
|
||||
# Map the normalized entropy to the desired temperature range using the power function
|
||||
dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))
|
||||
|
||||
# Apply the dynamically calculated temperature scaling
|
||||
scores = scores / dyn_temp
|
||||
|
||||
# print("----------------------\nTemperature from generation_config:", self.temperature)
|
||||
# print("min_temp:", min_temp)
|
||||
# print("max_temp:", max_temp)
|
||||
# print("Entropy:", entropy.item())
|
||||
# print("Max Possible Entropy considering valid tokens only:", max_entropy)
|
||||
# print("Normalized Entropy:", normalized_entropy.item())
|
||||
# print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
|
||||
# print("----------------------")
|
||||
|
||||
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
|
||||
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
|
||||
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class MinPLogitsWarper(LogitsWarper):
|
||||
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if min_p < 0 or min_p > 1.0:
|
||||
@ -198,14 +273,28 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
# presence_penalty and frequency_penalty
|
||||
raw_presence_penalty = (counts > 0).to(scores.dtype)
|
||||
raw_frequency_penalty = counts.to(scores.dtype)
|
||||
additive_penalty = raw_presence_penalty*self.presence_penalty + raw_frequency_penalty*self.frequency_penalty
|
||||
additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
|
||||
scores_row.scatter_add_(0, unique_ids, -additive_penalty)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_warper_patch(self, generation_config):
|
||||
# Make sure that temperature is float and not int
|
||||
if isinstance(generation_config.temperature, int):
|
||||
generation_config.temperature = float(generation_config.temperature)
|
||||
|
||||
temperature = generation_config.temperature
|
||||
if generation_config.dynatemp > 0:
|
||||
# Make sure TemperatureLogitsWarper will be created by temporarily
|
||||
# setting temperature to a value != 1.
|
||||
generation_config.temperature = 1.1
|
||||
|
||||
warpers = self._get_logits_warper_old(generation_config)
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynatemp)
|
||||
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||
|
||||
@ -232,18 +321,18 @@ def get_logits_warper_patch(self, generation_config):
|
||||
if generation_config.temperature_last:
|
||||
temperature_idx = None
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
|
||||
temperature_idx = i
|
||||
break
|
||||
|
||||
if temperature_idx is not None:
|
||||
warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
|
||||
warpers = LogitsProcessorList(warpers)
|
||||
warpers.append(warpers.pop(temperature_idx))
|
||||
|
||||
if normalize is not None:
|
||||
warpers.append(normalize)
|
||||
|
||||
warpers.append(SpyLogitsWarper())
|
||||
warpers = LogitsProcessorList(warpers)
|
||||
# for i in range(len(warpers)):
|
||||
# print(warpers[i].__class__.__name__)
|
||||
return warpers
|
||||
@ -272,6 +361,7 @@ def get_logits_processor_patch(self, **kwargs):
|
||||
def generation_config_init_patch(self, **kwargs):
|
||||
self.__init___old(**kwargs)
|
||||
self.min_p = kwargs.pop("min_p", 0.0)
|
||||
self.dynatemp = kwargs.pop("dynatemp", 0.0)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
self.top_a = kwargs.pop("top_a", 0.0)
|
||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
|
@ -283,7 +283,7 @@ def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if state['negative_prompt'] != '':
|
||||
|
@ -115,6 +115,7 @@ def list_interface_input_elements():
|
||||
'seed',
|
||||
'temperature',
|
||||
'temperature_last',
|
||||
'dynatemp',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
@ -49,6 +49,7 @@ def create_ui(default_preset):
|
||||
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||
shared.gradio['dynatemp'] = gr.Slider(0, 5, value=generate_params['dynatemp'], step=0.01, label='dynatemp')
|
||||
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||
|
4
presets/Dynamic Temperature.yaml
Normal file
4
presets/Dynamic Temperature.yaml
Normal file
@ -0,0 +1,4 @@
|
||||
temperature: 1.55
|
||||
temperature_last: true
|
||||
dynatemp: 1.45
|
||||
min_p: 0.05
|
Loading…
Reference in New Issue
Block a user