mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Add temperature_last parameter (#4472)
This commit is contained in:
parent
1ab8700d94
commit
aa5d671579
@ -25,6 +25,7 @@ def build_parameters(body, chat=False):
|
|||||||
'max_tokens_second': int(body.get('max_tokens_second', 0)),
|
'max_tokens_second': int(body.get('max_tokens_second', 0)),
|
||||||
'do_sample': bool(body.get('do_sample', True)),
|
'do_sample': bool(body.get('do_sample', True)),
|
||||||
'temperature': float(body.get('temperature', 0.5)),
|
'temperature': float(body.get('temperature', 0.5)),
|
||||||
|
'temperature_last': bool(body.get('temperature_last', False)),
|
||||||
'top_p': float(body.get('top_p', 1)),
|
'top_p': float(body.get('top_p', 1)),
|
||||||
'min_p': float(body.get('min_p', 0)),
|
'min_p': float(body.get('min_p', 0)),
|
||||||
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
|
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
|
||||||
|
@ -148,6 +148,7 @@ loaders_and_params = OrderedDict({
|
|||||||
loaders_samplers = {
|
loaders_samplers = {
|
||||||
'Transformers': {
|
'Transformers': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
@ -184,6 +185,7 @@ loaders_samplers = {
|
|||||||
},
|
},
|
||||||
'ExLlama_HF': {
|
'ExLlama_HF': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
@ -245,6 +247,7 @@ loaders_samplers = {
|
|||||||
},
|
},
|
||||||
'ExLlamav2_HF': {
|
'ExLlamav2_HF': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
@ -277,6 +280,7 @@ loaders_samplers = {
|
|||||||
},
|
},
|
||||||
'AutoGPTQ': {
|
'AutoGPTQ': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
@ -313,6 +317,7 @@ loaders_samplers = {
|
|||||||
},
|
},
|
||||||
'GPTQ-for-LLaMa': {
|
'GPTQ-for-LLaMa': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
@ -365,6 +370,7 @@ loaders_samplers = {
|
|||||||
},
|
},
|
||||||
'llamacpp_HF': {
|
'llamacpp_HF': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
@ -404,6 +410,7 @@ loaders_samplers = {
|
|||||||
},
|
},
|
||||||
'AutoAWQ': {
|
'AutoAWQ': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
|
@ -8,6 +8,7 @@ def default_preset():
|
|||||||
return {
|
return {
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
'temperature': 1,
|
'temperature': 1,
|
||||||
|
'temperature_last': False,
|
||||||
'top_p': 1,
|
'top_p': 1,
|
||||||
'min_p': 0,
|
'min_p': 0,
|
||||||
'top_k': 0,
|
'top_k': 0,
|
||||||
|
@ -12,6 +12,7 @@ from transformers.generation.logits_process import (
|
|||||||
|
|
||||||
global_scores = None
|
global_scores = None
|
||||||
|
|
||||||
|
|
||||||
class MinPLogitsWarper(LogitsWarper):
|
class MinPLogitsWarper(LogitsWarper):
|
||||||
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
if min_p < 0 or min_p > 1.0:
|
if min_p < 0 or min_p > 1.0:
|
||||||
@ -41,6 +42,7 @@ class MinPLogitsWarper(LogitsWarper):
|
|||||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class TailFreeLogitsWarper(LogitsWarper):
|
class TailFreeLogitsWarper(LogitsWarper):
|
||||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
tfs = float(tfs)
|
tfs = float(tfs)
|
||||||
@ -214,19 +216,36 @@ def get_logits_warper_patch(self, generation_config):
|
|||||||
if not isinstance(warper, TemperatureLogitsWarper):
|
if not isinstance(warper, TemperatureLogitsWarper):
|
||||||
warpers.remove(warper)
|
warpers.remove(warper)
|
||||||
else:
|
else:
|
||||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
|
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
|
||||||
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
|
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
|
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
|
||||||
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
|
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
if generation_config.min_p is not None and 0.0 <= generation_config.min_p <= 1.0:
|
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
|
||||||
warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
|
warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
|
|
||||||
if warpers and isinstance(warpers[-1], LogitNormalization):
|
if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
|
||||||
warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
|
normalize = warpers.pop(-1)
|
||||||
else:
|
else:
|
||||||
warpers += warpers_to_add
|
normalize = None
|
||||||
|
|
||||||
|
warpers += warpers_to_add
|
||||||
|
if generation_config.temperature_last:
|
||||||
|
temperature_idx = None
|
||||||
|
for i in range(len(warpers)):
|
||||||
|
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||||
|
temperature_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if temperature_idx is not None:
|
||||||
|
warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
|
||||||
|
warpers = LogitsProcessorList(warpers)
|
||||||
|
|
||||||
|
if normalize is not None:
|
||||||
|
warpers.append(normalize)
|
||||||
|
|
||||||
warpers.append(SpyLogitsWarper())
|
warpers.append(SpyLogitsWarper())
|
||||||
|
# for i in range(len(warpers)):
|
||||||
|
# print(warpers[i].__class__.__name__)
|
||||||
return warpers
|
return warpers
|
||||||
|
|
||||||
|
|
||||||
@ -261,6 +280,7 @@ def generation_config_init_patch(self, **kwargs):
|
|||||||
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
||||||
self.presence_penalty = kwargs.pop("presence_penalty", 0)
|
self.presence_penalty = kwargs.pop("presence_penalty", 0)
|
||||||
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
|
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
|
||||||
|
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||||
|
|
||||||
|
|
||||||
def hijack_samplers():
|
def hijack_samplers():
|
||||||
|
@ -274,7 +274,7 @@ def apply_stopping_strings(reply, all_stop_strings):
|
|||||||
|
|
||||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
||||||
generate_params[k] = state[k]
|
generate_params[k] = state[k]
|
||||||
|
|
||||||
if state['negative_prompt'] != '':
|
if state['negative_prompt'] != '':
|
||||||
|
@ -104,6 +104,7 @@ def list_interface_input_elements():
|
|||||||
'max_tokens_second',
|
'max_tokens_second',
|
||||||
'seed',
|
'seed',
|
||||||
'temperature',
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
'top_p',
|
'top_p',
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
|
@ -48,6 +48,7 @@ def create_ui(default_preset):
|
|||||||
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
||||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||||
|
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
|
||||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||||
with gr.Accordion('Other parameters', open=False):
|
with gr.Accordion('Other parameters', open=False):
|
||||||
|
Loading…
Reference in New Issue
Block a user