Allow API requests to use parameter presets

This commit is contained in:
oobabooga 2023-06-13 20:34:35 -03:00
parent 8936160e54
commit 474dc7355a
8 changed files with 96 additions and 58 deletions

View File

@ -19,6 +19,7 @@ async def run(user_input, history):
# Note: the selected defaults change from time to time. # Note: the selected defaults change from time to time.
request = { request = {
'user_input': user_input, 'user_input': user_input,
'max_new_tokens': 250,
'history': history, 'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example', 'character': 'Example',
@ -32,7 +33,9 @@ async def run(user_input, history):
'chat_generation_attempts': 1, 'chat_generation_attempts': 1,
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'max_new_tokens': 250, # Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True, 'do_sample': True,
'temperature': 0.7, 'temperature': 0.7,
'top_p': 0.1, 'top_p': 0.1,
@ -52,6 +55,7 @@ async def run(user_input, history):
'mirostat_mode': 0, 'mirostat_mode': 0,
'mirostat_tau': 5, 'mirostat_tau': 5,
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
'seed': -1, 'seed': -1,
'add_bos_token': True, 'add_bos_token': True,
'truncation_length': 2048, 'truncation_length': 2048,

View File

@ -13,6 +13,7 @@ URI = f'http://{HOST}/api/v1/chat'
def run(user_input, history): def run(user_input, history):
request = { request = {
'user_input': user_input, 'user_input': user_input,
'max_new_tokens': 250,
'history': history, 'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example', 'character': 'Example',
@ -26,7 +27,9 @@ def run(user_input, history):
'chat_generation_attempts': 1, 'chat_generation_attempts': 1,
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'max_new_tokens': 250, # Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True, 'do_sample': True,
'temperature': 0.7, 'temperature': 0.7,
'top_p': 0.1, 'top_p': 0.1,
@ -46,6 +49,7 @@ def run(user_input, history):
'mirostat_mode': 0, 'mirostat_mode': 0,
'mirostat_tau': 5, 'mirostat_tau': 5,
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
'seed': -1, 'seed': -1,
'add_bos_token': True, 'add_bos_token': True,
'truncation_length': 2048, 'truncation_length': 2048,

View File

@ -20,8 +20,12 @@ async def run(context):
request = { request = {
'prompt': context, 'prompt': context,
'max_new_tokens': 250, 'max_new_tokens': 250,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True, 'do_sample': True,
'temperature': 1.3, 'temperature': 0.7,
'top_p': 0.1, 'top_p': 0.1,
'typical_p': 1, 'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4 'epsilon_cutoff': 0, # In units of 1e-4
@ -39,6 +43,7 @@ async def run(context):
'mirostat_mode': 0, 'mirostat_mode': 0,
'mirostat_tau': 5, 'mirostat_tau': 5,
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
'seed': -1, 'seed': -1,
'add_bos_token': True, 'add_bos_token': True,
'truncation_length': 2048, 'truncation_length': 2048,

View File

@ -12,8 +12,12 @@ def run(prompt):
request = { request = {
'prompt': prompt, 'prompt': prompt,
'max_new_tokens': 250, 'max_new_tokens': 250,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True, 'do_sample': True,
'temperature': 1.3, 'temperature': 0.7,
'top_p': 0.1, 'top_p': 0.1,
'typical_p': 1, 'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4 'epsilon_cutoff': 0, # In units of 1e-4
@ -31,6 +35,7 @@ def run(prompt):
'mirostat_mode': 0, 'mirostat_mode': 0,
'mirostat_tau': 5, 'mirostat_tau': 5,
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
'seed': -1, 'seed': -1,
'add_bos_token': True, 'add_bos_token': True,
'truncation_length': 2048, 'truncation_length': 2048,

View File

@ -5,6 +5,7 @@ from typing import Callable, Optional
from modules import shared from modules import shared
from modules.chat import load_character_memoized from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized
def build_parameters(body, chat=False): def build_parameters(body, chat=False):
@ -40,6 +41,13 @@ def build_parameters(body, chat=False):
'stopping_strings': body.get('stopping_strings', []), 'stopping_strings': body.get('stopping_strings', []),
} }
preset_name = body.get('preset', 'None')
if preset_name not in ['None', None, '']:
print(preset_name)
preset = load_preset_memoized(preset_name)
print(preset)
generate_params.update(preset)
if chat: if chat:
character = body.get('character') character = body.get('character')
instruction_template = body.get('instruction_template') instruction_template = body.get('instruction_template')

55
modules/presets.py Normal file
View File

@ -0,0 +1,55 @@
import functools
from pathlib import Path
import yaml
def load_preset(name):
    """Load the generation parameters stored in presets/<name>.yaml.

    Starts from the library-default value for every sampler setting and
    overrides them with whatever keys the preset file defines, so the
    returned dict always contains the complete set of generation params.

    Raises FileNotFoundError if presets/<name>.yaml does not exist.
    """
    # Defaults; any key present in the preset file overrides these.
    generate_params = {
        'do_sample': True,
        'temperature': 1,
        'top_p': 1,
        'typical_p': 1,
        'epsilon_cutoff': 0,
        'eta_cutoff': 0,
        'tfs': 1,
        'top_a': 0,
        'repetition_penalty': 1,
        'encoder_repetition_penalty': 1,
        'top_k': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'min_length': 0,
        'length_penalty': 1,
        'no_repeat_ngram_size': 0,
        'early_stopping': False,
        'mirostat_mode': 0,
        'mirostat_tau': 5.0,
        'mirostat_eta': 0.1,
    }

    with open(Path(f'presets/{name}.yaml'), 'r') as infile:
        preset = yaml.safe_load(infile)

    # safe_load returns None for an empty file; fall back to the defaults
    # instead of crashing. Preset keys win over the defaults above.
    generate_params.update(preset or {})

    # Cap temperature at 1.99 — presumably the UI slider's upper bound; confirm.
    generate_params['temperature'] = min(1.99, generate_params['temperature'])
    return generate_params
@functools.cache
def load_preset_memoized(name):
    """Cached wrapper around load_preset(): repeated requests for the
    same preset name reuse the parsed result instead of re-reading the
    YAML file from disk."""
    return load_preset(name)
def load_preset_for_ui(name, state):
    """Apply preset `name` to the UI state dict and return the refreshed
    state followed by one value per gradio sampler widget, in the order
    the widgets are wired up."""
    params = load_preset(name)
    state.update(params)
    widget_keys = [
        'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff',
        'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty',
        'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams',
        'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode',
        'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a',
    ]
    return state, *(params[k] for k in widget_keys)
def generate_preset_yaml(state):
    """Serialize the sampler settings found in `state` into preset YAML,
    keeping the canonical key order (sort_keys=False)."""
    preset_keys = [
        'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff',
        'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty',
        'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams',
        'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode',
        'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a',
    ]
    selected = {key: state[key] for key in preset_keys}
    return yaml.dump(selected, sort_keys=False)

View File

@ -1,5 +1,6 @@
import os import os
import re import re
from datetime import datetime
from pathlib import Path from pathlib import Path
from modules import shared from modules import shared
@ -41,6 +42,10 @@ def delete_file(fname):
logger.info(f'Deleted {fname}.') logger.info(f'Deleted {fname}.')
def current_time():
    """Return the current local time formatted as 'YYYY-MM-DD-HHMMSS'
    (used to build default file names for saved prompts)."""
    # strftime already returns a str; the f-string wrapper was redundant.
    return datetime.now().strftime('%Y-%m-%d-%H%M%S')
def atoi(text): def atoi(text):
return int(text) if text.isdigit() else text.lower() return int(text) if text.isdigit() else text.lower()

View File

@ -33,7 +33,6 @@ import re
import sys import sys
import time import time
import traceback import traceback
from datetime import datetime
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock
@ -44,7 +43,7 @@ import yaml
from PIL import Image from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
from modules import chat, shared, training, ui, utils from modules import chat, presets, shared, training, ui, utils
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.github import clone_or_pull_repository from modules.github import clone_or_pull_repository
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
@ -80,53 +79,6 @@ def load_lora_wrapper(selected_loras):
yield ("Successfuly applied the LoRAs") yield ("Successfuly applied the LoRAs")
def load_preset_values(preset_menu, state, return_dict=False):
    """Load generation parameters from presets/<preset_menu>.yaml.

    The defaults below are used for any key the preset file does not set.
    With return_dict=True, returns only the merged parameter dict;
    otherwise updates `state` in place and returns it together with the
    individual values in the order expected by the gradio UI widgets.

    Raises FileNotFoundError if the preset file does not exist.
    """
    # Library-default sampler settings; preset file values override these.
    generate_params = {
        'do_sample': True,
        'temperature': 1,
        'top_p': 1,
        'typical_p': 1,
        'epsilon_cutoff': 0,
        'eta_cutoff': 0,
        'tfs': 1,
        'top_a': 0,
        'repetition_penalty': 1,
        'encoder_repetition_penalty': 1,
        'top_k': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'min_length': 0,
        'length_penalty': 1,
        'no_repeat_ngram_size': 0,
        'early_stopping': False,
        'mirostat_mode': 0,
        'mirostat_tau': 5.0,
        'mirostat_eta': 0.1,
    }
    with open(Path(f'presets/{preset_menu}.yaml'), 'r') as infile:
        preset = yaml.safe_load(infile)
    # Copy every key from the preset over the defaults.
    for k in preset:
        generate_params[k] = preset[k]
    # Cap temperature at 1.99 — presumably the UI slider's upper bound; confirm.
    generate_params['temperature'] = min(1.99, generate_params['temperature'])
    if return_dict:
        return generate_params
    else:
        state.update(generate_params)
        # Tuple tail must stay in this exact order: it maps 1:1 onto the
        # gradio output widgets wired up in create_settings_menus().
        return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
def generate_preset_yaml(state):
    """Dump the sampler settings in `state` as preset YAML, preserving
    the canonical key order (sort_keys=False)."""
    keys = [
        'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff',
        'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty',
        'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams',
        'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode',
        'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a',
    ]
    payload = {key: state[key] for key in keys}
    return yaml.dump(payload, sort_keys=False)
def current_time():
    """Return the current local time formatted as 'YYYY-MM-DD-HHMMSS'."""
    # strftime already returns a str; the f-string wrapper was redundant.
    return datetime.now().strftime('%Y-%m-%d-%H%M%S')
def load_prompt(fname): def load_prompt(fname):
if fname in ['None', '']: if fname in ['None', '']:
return '' return ''
@ -251,7 +203,7 @@ def get_model_specific_settings(model):
return model_settings return model_settings
def load_model_specific_settings(model, state, return_dict=False): def load_model_specific_settings(model, state):
model_settings = get_model_specific_settings(model) model_settings = get_model_specific_settings(model)
for k in model_settings: for k in model_settings:
if k in state: if k in state:
@ -448,7 +400,7 @@ def create_chat_settings_menus():
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) generate_params = presets.load_preset(default_preset)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
@ -515,7 +467,7 @@ def create_settings_menus(default_preset):
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming') shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
def create_file_saving_menus(): def create_file_saving_menus():
@ -578,7 +530,7 @@ def create_file_saving_event_handlers():
shared.gradio['save_preset'].click( shared.gradio['save_preset'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then( presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
lambda: 'presets/', None, shared.gradio['save_root']).then( lambda: 'presets/', None, shared.gradio['save_root']).then(
lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then( lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver']) lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
@ -1043,7 +995,7 @@ def create_interface():
shared.gradio['save_prompt'].click( shared.gradio['save_prompt'].click(
lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then( lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then(
lambda: 'prompts/', None, shared.gradio['save_root']).then( lambda: 'prompts/', None, shared.gradio['save_root']).then(
lambda: current_time() + '.txt', None, shared.gradio['save_filename']).then( lambda: utils.current_time() + '.txt', None, shared.gradio['save_filename']).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver']) lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
shared.gradio['delete_prompt'].click( shared.gradio['delete_prompt'].click(