Allow API requests to use parameter presets

This commit is contained in:
oobabooga 2023-06-13 20:34:35 -03:00
parent 8936160e54
commit 474dc7355a
8 changed files with 96 additions and 58 deletions

View File

@ -19,6 +19,7 @@ async def run(user_input, history):
# Note: the selected defaults change from time to time.
request = {
'user_input': user_input,
'max_new_tokens': 250,
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',
@ -32,7 +33,9 @@ async def run(user_input, history):
'chat_generation_attempts': 1,
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'max_new_tokens': 250,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
@ -52,6 +55,7 @@ async def run(user_input, history):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,

View File

@ -13,6 +13,7 @@ URI = f'http://{HOST}/api/v1/chat'
def run(user_input, history):
request = {
'user_input': user_input,
'max_new_tokens': 250,
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',
@ -26,7 +27,9 @@ def run(user_input, history):
'chat_generation_attempts': 1,
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'max_new_tokens': 250,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
@ -46,6 +49,7 @@ def run(user_input, history):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,

View File

@ -20,8 +20,12 @@ async def run(context):
request = {
'prompt': context,
'max_new_tokens': 250,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 1.3,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
@ -39,6 +43,7 @@ async def run(context):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,

View File

@ -12,8 +12,12 @@ def run(prompt):
request = {
'prompt': prompt,
'max_new_tokens': 250,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 1.3,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
@ -31,6 +35,7 @@ def run(prompt):
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,

View File

@ -5,6 +5,7 @@ from typing import Callable, Optional
from modules import shared
from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized
def build_parameters(body, chat=False):
@ -40,6 +41,13 @@ def build_parameters(body, chat=False):
'stopping_strings': body.get('stopping_strings', []),
}
preset_name = body.get('preset', 'None')
if preset_name not in ['None', None, '']:
print(preset_name)
preset = load_preset_memoized(preset_name)
print(preset)
generate_params.update(preset)
if chat:
character = body.get('character')
instruction_template = body.get('instruction_template')

55
modules/presets.py Normal file
View File

@ -0,0 +1,55 @@
import functools
from pathlib import Path
import yaml
def load_preset(name):
generate_params = {
'do_sample': True,
'temperature': 1,
'top_p': 1,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 0,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
}
with open(Path(f'presets/{name}.yaml'), 'r') as infile:
preset = yaml.safe_load(infile)
for k in preset:
generate_params[k] = preset[k]
generate_params['temperature'] = min(1.99, generate_params['temperature'])
return generate_params
@functools.cache
def load_preset_memoized(name):
return load_preset(name)
def load_preset_for_ui(name, state):
generate_params = load_preset(name)
state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
def generate_preset_yaml(state):
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
return yaml.dump(data, sort_keys=False)

View File

@ -1,5 +1,6 @@
import os
import re
from datetime import datetime
from pathlib import Path
from modules import shared
@ -41,6 +42,10 @@ def delete_file(fname):
logger.info(f'Deleted {fname}.')
def current_time():
return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
def atoi(text):
return int(text) if text.isdigit() else text.lower()

View File

@ -33,7 +33,6 @@ import re
import sys
import time
import traceback
from datetime import datetime
from functools import partial
from pathlib import Path
from threading import Lock
@ -44,7 +43,7 @@ import yaml
from PIL import Image
import modules.extensions as extensions_module
from modules import chat, shared, training, ui, utils
from modules import chat, presets, shared, training, ui, utils
from modules.extensions import apply_extensions
from modules.github import clone_or_pull_repository
from modules.html_generator import chat_html_wrapper
@ -80,53 +79,6 @@ def load_lora_wrapper(selected_loras):
yield ("Successfuly applied the LoRAs")
def load_preset_values(preset_menu, state, return_dict=False):
generate_params = {
'do_sample': True,
'temperature': 1,
'top_p': 1,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 0,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
}
with open(Path(f'presets/{preset_menu}.yaml'), 'r') as infile:
preset = yaml.safe_load(infile)
for k in preset:
generate_params[k] = preset[k]
generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict:
return generate_params
else:
state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
def generate_preset_yaml(state):
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
return yaml.dump(data, sort_keys=False)
def current_time():
return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
def load_prompt(fname):
if fname in ['None', '']:
return ''
@ -251,7 +203,7 @@ def get_model_specific_settings(model):
return model_settings
def load_model_specific_settings(model, state, return_dict=False):
def load_model_specific_settings(model, state):
model_settings = get_model_specific_settings(model)
for k in model_settings:
if k in state:
@ -448,7 +400,7 @@ def create_chat_settings_menus():
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
generate_params = presets.load_preset(default_preset)
with gr.Row():
with gr.Column():
with gr.Row():
@ -515,7 +467,7 @@ def create_settings_menus(default_preset):
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
def create_file_saving_menus():
@ -578,7 +530,7 @@ def create_file_saving_event_handlers():
shared.gradio['save_preset'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
lambda: 'presets/', None, shared.gradio['save_root']).then(
lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
@ -1043,7 +995,7 @@ def create_interface():
shared.gradio['save_prompt'].click(
lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then(
lambda: 'prompts/', None, shared.gradio['save_root']).then(
lambda: current_time() + '.txt', None, shared.gradio['save_filename']).then(
lambda: utils.current_time() + '.txt', None, shared.gradio['save_filename']).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
shared.gradio['delete_prompt'].click(