mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Add dynamic temperature to the random preset button
This commit is contained in:
parent
0d07b3a6a1
commit
4365fb890f
@ -1,4 +1,5 @@
|
|||||||
import functools
|
import functools
|
||||||
|
import pprint
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -6,6 +7,7 @@ import yaml
|
|||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.loaders import loaders_samplers
|
from modules.loaders import loaders_samplers
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def default_preset():
|
def default_preset():
|
||||||
@ -81,7 +83,8 @@ def random_preset(state):
|
|||||||
'eta_cutoff': [3, 6, 9, 12, 15, 18],
|
'eta_cutoff': [3, 6, 9, 12, 15, 18],
|
||||||
},
|
},
|
||||||
'flatten_distribution': {
|
'flatten_distribution': {
|
||||||
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0],
|
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0, 5.0],
|
||||||
|
'dynamic_temperature_low': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0],
|
||||||
},
|
},
|
||||||
'repetition': {
|
'repetition': {
|
||||||
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
||||||
@ -90,20 +93,51 @@ def random_preset(state):
|
|||||||
},
|
},
|
||||||
'other': {
|
'other': {
|
||||||
'temperature_last': [True, False],
|
'temperature_last': [True, False],
|
||||||
|
'dynamic_temperature': [True, False],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
generate_params = default_preset()
|
generate_params = default_preset()
|
||||||
|
defaults = default_preset()
|
||||||
|
|
||||||
for cat in params_and_values:
|
for cat in params_and_values:
|
||||||
choices = list(params_and_values[cat].keys())
|
choices = list(params_and_values[cat].keys())
|
||||||
if shared.args.loader is not None:
|
if shared.args.loader is not None:
|
||||||
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
|
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
|
||||||
|
|
||||||
if len(choices) > 0:
|
if len(choices) > 0:
|
||||||
choice = random.choice(choices)
|
if cat == 'other':
|
||||||
generate_params[choice] = random.choice(params_and_values[cat][choice])
|
N = random.randint(1, len(choices))
|
||||||
|
maybe_multiple_choices = random.sample(choices, N)
|
||||||
|
for choice in maybe_multiple_choices:
|
||||||
|
generate_params[choice] = random.choice(params_and_values[cat][choice])
|
||||||
|
else:
|
||||||
|
choice = random.choice(choices)
|
||||||
|
generate_params[choice] = random.choice(params_and_values[cat][choice])
|
||||||
|
|
||||||
|
# If using dynamic temperature, sample the high/low values simultaneously.
|
||||||
|
# If necessary, resample until the low is lower than the high
|
||||||
|
if generate_params['dynamic_temperature']:
|
||||||
|
generate_params['dynamic_temperature_low'] = random.choice(params_and_values['flatten_distribution']['dynamic_temperature_low'])
|
||||||
|
generate_params['temperature'] = random.choice(params_and_values['flatten_distribution']['temperature'])
|
||||||
|
while generate_params['dynamic_temperature_low'] >= generate_params['temperature']:
|
||||||
|
generate_params['dynamic_temperature_low'] = random.choice(params_and_values['flatten_distribution']['dynamic_temperature_low'])
|
||||||
|
generate_params['temperature'] = random.choice(params_and_values['flatten_distribution']['temperature'])
|
||||||
|
elif 'dynamic_temperature_low' in generate_params:
|
||||||
|
generate_params['dynamic_temperature_low'] = defaults['dynamic_temperature_low']
|
||||||
|
|
||||||
state.update(generate_params)
|
state.update(generate_params)
|
||||||
|
|
||||||
|
diff = {}
|
||||||
|
# Remove entries that are identical to the defaults
|
||||||
|
for k in list(generate_params.keys()):
|
||||||
|
if generate_params[k] != defaults[k]:
|
||||||
|
diff[k] = generate_params[k]
|
||||||
|
|
||||||
|
logger.info("GENERATED_PRESET=")
|
||||||
|
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(diff)
|
||||||
|
print()
|
||||||
|
|
||||||
return state, *[generate_params[k] for k in presets_params()]
|
return state, *[generate_params[k] for k in presets_params()]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user