Add dynamic temperature to the random preset button

This commit is contained in:
oobabooga 2024-01-07 13:07:32 -08:00
parent 0d07b3a6a1
commit 4365fb890f

View File

@ -1,4 +1,5 @@
import functools import functools
import pprint
import random import random
from pathlib import Path from pathlib import Path
@ -6,6 +7,7 @@ import yaml
from modules import shared from modules import shared
from modules.loaders import loaders_samplers from modules.loaders import loaders_samplers
from modules.logging_colors import logger
def default_preset(): def default_preset():
@ -81,7 +83,8 @@ def random_preset(state):
'eta_cutoff': [3, 6, 9, 12, 15, 18], 'eta_cutoff': [3, 6, 9, 12, 15, 18],
}, },
'flatten_distribution': { 'flatten_distribution': {
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0], 'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0, 5.0],
'dynamic_temperature_low': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0],
}, },
'repetition': { 'repetition': {
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25], 'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
@ -90,20 +93,51 @@ def random_preset(state):
}, },
'other': { 'other': {
'temperature_last': [True, False], 'temperature_last': [True, False],
'dynamic_temperature': [True, False],
} }
} }
generate_params = default_preset() generate_params = default_preset()
defaults = default_preset()
for cat in params_and_values: for cat in params_and_values:
choices = list(params_and_values[cat].keys()) choices = list(params_and_values[cat].keys())
if shared.args.loader is not None: if shared.args.loader is not None:
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]] choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
if len(choices) > 0: if len(choices) > 0:
if cat == 'other':
N = random.randint(1, len(choices))
maybe_multiple_choices = random.sample(choices, N)
for choice in maybe_multiple_choices:
generate_params[choice] = random.choice(params_and_values[cat][choice])
else:
choice = random.choice(choices) choice = random.choice(choices)
generate_params[choice] = random.choice(params_and_values[cat][choice]) generate_params[choice] = random.choice(params_and_values[cat][choice])
# If using dynamic temperature, sample the high/low values simultaneously.
# If necessary, resample until the low is lower than the high
if generate_params['dynamic_temperature']:
generate_params['dynamic_temperature_low'] = random.choice(params_and_values['flatten_distribution']['dynamic_temperature_low'])
generate_params['temperature'] = random.choice(params_and_values['flatten_distribution']['temperature'])
while generate_params['dynamic_temperature_low'] >= generate_params['temperature']:
generate_params['dynamic_temperature_low'] = random.choice(params_and_values['flatten_distribution']['dynamic_temperature_low'])
generate_params['temperature'] = random.choice(params_and_values['flatten_distribution']['temperature'])
elif 'dynamic_temperature_low' in generate_params:
generate_params['dynamic_temperature_low'] = defaults['dynamic_temperature_low']
state.update(generate_params) state.update(generate_params)
diff = {}
# Remove entries that are identical to the defaults
for k in list(generate_params.keys()):
if generate_params[k] != defaults[k]:
diff[k] = generate_params[k]
logger.info("GENERATED_PRESET=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(diff)
print()
return state, *[generate_params[k] for k in presets_params()] return state, *[generate_params[k] for k in presets_params()]