From 4365fb890f851e4ca9629a85b7792477cac983ae Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 7 Jan 2024 13:07:32 -0800 Subject: [PATCH] Add dynamic temperature to the random preset button --- modules/presets.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/modules/presets.py b/modules/presets.py index 42ca7820..21334e0f 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -1,4 +1,5 @@ import functools +import pprint import random from pathlib import Path @@ -6,6 +7,7 @@ import yaml from modules import shared from modules.loaders import loaders_samplers +from modules.logging_colors import logger def default_preset(): @@ -81,7 +83,8 @@ def random_preset(state): 'eta_cutoff': [3, 6, 9, 12, 15, 18], }, 'flatten_distribution': { - 'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0], + 'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0, 5.0], + 'dynamic_temperature_low': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0], }, 'repetition': { 'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25], @@ -90,20 +93,51 @@ def random_preset(state): }, 'other': { 'temperature_last': [True, False], + 'dynamic_temperature': [True, False], } } generate_params = default_preset() + defaults = default_preset() + for cat in params_and_values: choices = list(params_and_values[cat].keys()) if shared.args.loader is not None: choices = [x for x in choices if x in loaders_samplers[shared.args.loader]] if len(choices) > 0: - choice = random.choice(choices) - generate_params[choice] = random.choice(params_and_values[cat][choice]) + if cat == 'other': + N = random.randint(1, len(choices)) + maybe_multiple_choices = random.sample(choices, N) + for choice in maybe_multiple_choices: + generate_params[choice] = random.choice(params_and_values[cat][choice]) + else: + choice = random.choice(choices) + generate_params[choice] = random.choice(params_and_values[cat][choice]) + + # If using dynamic temperature, sample the high/low values simultaneously. + # If necessary, resample until the low is lower than the high + if generate_params['dynamic_temperature']: + generate_params['dynamic_temperature_low'] = random.choice(params_and_values['flatten_distribution']['dynamic_temperature_low']) + generate_params['temperature'] = random.choice(params_and_values['flatten_distribution']['temperature']) + while generate_params['dynamic_temperature_low'] >= generate_params['temperature']: + generate_params['dynamic_temperature_low'] = random.choice(params_and_values['flatten_distribution']['dynamic_temperature_low']) + generate_params['temperature'] = random.choice(params_and_values['flatten_distribution']['temperature']) + elif 'dynamic_temperature_low' in generate_params: + generate_params['dynamic_temperature_low'] = defaults['dynamic_temperature_low'] state.update(generate_params) + + diff = {} + # Remove entries that are identical to the defaults + for k in list(generate_params.keys()): + if generate_params[k] != defaults[k]: + diff[k] = generate_params[k] + + logger.info("GENERATED_PRESET=") + pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(diff) + print() + return state, *[generate_params[k] for k in presets_params()]