Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-10-01 01:26:03 -04:00)

Commit 474dc7355a: Allow API requests to use parameter presets
Parent: 8936160e54
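
In short: the preset-loading logic moves out of server.py into a new modules/presets.py, current_time() moves into modules/utils.py, and the API (extensions/api/util.py plus the four api-examples scripts) gains a 'preset' request field. When a request names a preset, the values from presets/<name>.yaml override the individual generation parameters sent in the request.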
api-examples/api-example-chat-stream.py

@@ -19,6 +19,7 @@ async def run(user_input, history):
     # Note: the selected defaults change from time to time.
     request = {
         'user_input': user_input,
+        'max_new_tokens': 250,
         'history': history,
         'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
@@ -32,7 +33,9 @@ async def run(user_input, history):
         'chat_generation_attempts': 1,
         'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
 
-        'max_new_tokens': 250,
+        # Generation params. If 'preset' is set to different than 'None', the values
+        # in presets/preset-name.yaml are used instead of the individual numbers.
+        'preset': 'None',
         'do_sample': True,
         'temperature': 0.7,
         'top_p': 0.1,
@@ -52,6 +55,7 @@ async def run(user_input, history):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+
         'seed': -1,
         'add_bos_token': True,
         'truncation_length': 2048,
api-examples/api-example-chat.py

@@ -13,6 +13,7 @@ URI = f'http://{HOST}/api/v1/chat'
 def run(user_input, history):
     request = {
         'user_input': user_input,
+        'max_new_tokens': 250,
         'history': history,
         'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
@@ -26,7 +27,9 @@ def run(user_input, history):
         'chat_generation_attempts': 1,
         'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
 
-        'max_new_tokens': 250,
+        # Generation params. If 'preset' is set to different than 'None', the values
+        # in presets/preset-name.yaml are used instead of the individual numbers.
+        'preset': 'None',
         'do_sample': True,
         'temperature': 0.7,
         'top_p': 0.1,
@@ -46,6 +49,7 @@ def run(user_input, history):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+
         'seed': -1,
         'add_bos_token': True,
         'truncation_length': 2048,
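
(A minimal usage sketch, not part of the commit: the host/port and the preset name below are assumptions for illustration; the request and response shapes follow api-example-chat.py.)

    import requests

    URI = 'http://localhost:5000/api/v1/chat'  # assumes --api on the default port

    request = {
        'user_input': 'Please write a haiku about trees.',
        'max_new_tokens': 250,
        'history': {'internal': [], 'visible': []},
        'mode': 'instruct',
        # New in this commit: if a preset is named, the values in
        # presets/LLaMA-Precise.yaml override the sampling params above.
        'preset': 'LLaMA-Precise',  # hypothetical choice of preset
    }

    response = requests.post(URI, json=request)
    if response.status_code == 200:
        history = response.json()['results'][0]['history']
        print(history['visible'][-1][1])  # the newest visible reply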
api-examples/api-example-stream.py

@@ -20,8 +20,12 @@ async def run(context):
     request = {
         'prompt': context,
         'max_new_tokens': 250,
+
+        # Generation params. If 'preset' is set to different than 'None', the values
+        # in presets/preset-name.yaml are used instead of the individual numbers.
+        'preset': 'None',
         'do_sample': True,
-        'temperature': 1.3,
+        'temperature': 0.7,
         'top_p': 0.1,
         'typical_p': 1,
         'epsilon_cutoff': 0, # In units of 1e-4
@@ -39,6 +43,7 @@ async def run(context):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+
         'seed': -1,
         'add_bos_token': True,
         'truncation_length': 2048,
api-examples/api-example.py

@@ -12,8 +12,12 @@ def run(prompt):
     request = {
         'prompt': prompt,
         'max_new_tokens': 250,
+
+        # Generation params. If 'preset' is set to different than 'None', the values
+        # in presets/preset-name.yaml are used instead of the individual numbers.
+        'preset': 'None',
         'do_sample': True,
-        'temperature': 1.3,
+        'temperature': 0.7,
         'top_p': 0.1,
         'typical_p': 1,
         'epsilon_cutoff': 0, # In units of 1e-4
@@ -31,6 +35,7 @@ def run(prompt):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+
         'seed': -1,
         'add_bos_token': True,
         'truncation_length': 2048,
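
(The same field works on the plain completion route; another sketch under the same assumptions. 'Naive' is borrowed from the flexgen fallback visible in the server.py hunks below, so a presets/Naive.yaml should exist.)

    import requests

    response = requests.post('http://localhost:5000/api/v1/generate', json={
        'prompt': 'In a shocking finding, scientists discovered',
        'max_new_tokens': 250,
        'preset': 'Naive',
    })
    if response.status_code == 200:
        print(response.json()['results'][0]['text'])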
extensions/api/util.py

@@ -5,6 +5,7 @@ from typing import Callable, Optional
 
 from modules import shared
 from modules.chat import load_character_memoized
+from modules.presets import load_preset_memoized
 
 
 def build_parameters(body, chat=False):
@@ -40,6 +41,13 @@ def build_parameters(body, chat=False):
         'stopping_strings': body.get('stopping_strings', []),
     }
 
+    preset_name = body.get('preset', 'None')
+    if preset_name not in ['None', None, '']:
+        print(preset_name)
+        preset = load_preset_memoized(preset_name)
+        print(preset)
+        generate_params.update(preset)
+
     if chat:
         character = body.get('character')
         instruction_template = body.get('instruction_template')
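
(Note the merge order this hunk establishes: generate_params is filled from the request body first, then dict.update() lets the preset overwrite any overlapping keys. A toy illustration with invented values:)

    generate_params = {'temperature': 0.7, 'top_p': 0.1}  # from the request body
    preset = {'temperature': 1.31, 'top_k': 40}           # from presets/<name>.yaml
    generate_params.update(preset)
    # generate_params is now {'temperature': 1.31, 'top_p': 0.1, 'top_k': 40}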
modules/presets.py (new file, 55 lines)

@@ -0,0 +1,55 @@
+import functools
+from pathlib import Path
+
+import yaml
+
+
+def load_preset(name):
+    generate_params = {
+        'do_sample': True,
+        'temperature': 1,
+        'top_p': 1,
+        'typical_p': 1,
+        'epsilon_cutoff': 0,
+        'eta_cutoff': 0,
+        'tfs': 1,
+        'top_a': 0,
+        'repetition_penalty': 1,
+        'encoder_repetition_penalty': 1,
+        'top_k': 0,
+        'num_beams': 1,
+        'penalty_alpha': 0,
+        'min_length': 0,
+        'length_penalty': 1,
+        'no_repeat_ngram_size': 0,
+        'early_stopping': False,
+        'mirostat_mode': 0,
+        'mirostat_tau': 5.0,
+        'mirostat_eta': 0.1,
+    }
+
+    with open(Path(f'presets/{name}.yaml'), 'r') as infile:
+        preset = yaml.safe_load(infile)
+
+    for k in preset:
+        generate_params[k] = preset[k]
+
+    generate_params['temperature'] = min(1.99, generate_params['temperature'])
+    return generate_params
+
+
+@functools.cache
+def load_preset_memoized(name):
+    return load_preset(name)
+
+
+def load_preset_for_ui(name, state):
+    generate_params = load_preset(name)
+    state.update(generate_params)
+    return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
+
+
+def generate_preset_yaml(state):
+    data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
+    return yaml.dump(data, sort_keys=False)
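
(Two properties of the new module worth noting: functools.cache requires Python 3.9+, and it returns the same dict object on every hit, so callers should treat the result as read-only; build_parameters above only reads it through update(). A sketch, reusing the 'Naive' preset name from the server.py hunks below:)

    from modules import presets

    params = presets.load_preset_memoized('Naive')        # first call reads presets/Naive.yaml
    params_again = presets.load_preset_memoized('Naive')  # second call is served from the cache
    assert params is params_again

    # The cached dict is shared; mutate a copy, never the original.
    local = dict(params)
    local['temperature'] = 0.5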
modules/utils.py

@@ -1,5 +1,6 @@
 import os
 import re
+from datetime import datetime
 from pathlib import Path
 
 from modules import shared
@@ -41,6 +42,10 @@ def delete_file(fname):
     logger.info(f'Deleted {fname}.')
 
 
+def current_time():
+    return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
+
+
 def atoi(text):
     return int(text) if text.isdigit() else text.lower()
 
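
(current_time() is a straight move from server.py, hence the new datetime import above. For reference, the strftime pattern produces names whose lexicographic order matches chronological order; the value below is invented:)

    from datetime import datetime

    # '%Y-%m-%d-%H%M%S' -> e.g. '2023-05-29-181530'; server.py appends '.txt'
    # when stamping saved prompt files.
    print(datetime.now().strftime('%Y-%m-%d-%H%M%S'))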
server.py
@@ -33,7 +33,6 @@ import re
 import sys
 import time
 import traceback
-from datetime import datetime
 from functools import partial
 from pathlib import Path
 from threading import Lock
@@ -44,7 +43,7 @@ import yaml
 from PIL import Image
 
 import modules.extensions as extensions_module
-from modules import chat, shared, training, ui, utils
+from modules import chat, presets, shared, training, ui, utils
 from modules.extensions import apply_extensions
 from modules.github import clone_or_pull_repository
 from modules.html_generator import chat_html_wrapper
@@ -80,53 +79,6 @@ def load_lora_wrapper(selected_loras):
     yield ("Successfuly applied the LoRAs")
 
 
-def load_preset_values(preset_menu, state, return_dict=False):
-    generate_params = {
-        'do_sample': True,
-        'temperature': 1,
-        'top_p': 1,
-        'typical_p': 1,
-        'epsilon_cutoff': 0,
-        'eta_cutoff': 0,
-        'tfs': 1,
-        'top_a': 0,
-        'repetition_penalty': 1,
-        'encoder_repetition_penalty': 1,
-        'top_k': 0,
-        'num_beams': 1,
-        'penalty_alpha': 0,
-        'min_length': 0,
-        'length_penalty': 1,
-        'no_repeat_ngram_size': 0,
-        'early_stopping': False,
-        'mirostat_mode': 0,
-        'mirostat_tau': 5.0,
-        'mirostat_eta': 0.1,
-    }
-
-    with open(Path(f'presets/{preset_menu}.yaml'), 'r') as infile:
-        preset = yaml.safe_load(infile)
-
-    for k in preset:
-        generate_params[k] = preset[k]
-
-    generate_params['temperature'] = min(1.99, generate_params['temperature'])
-    if return_dict:
-        return generate_params
-    else:
-        state.update(generate_params)
-        return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
-
-
-def generate_preset_yaml(state):
-    data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
-    return yaml.dump(data, sort_keys=False)
-
-
-def current_time():
-    return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
-
-
 def load_prompt(fname):
     if fname in ['None', '']:
         return ''
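
(The 47 lines removed here are not lost: load_preset_values is split into load_preset and load_preset_for_ui in modules/presets.py, replacing the return_dict flag with two separate functions; generate_preset_yaml moves there unchanged; and current_time now lives in modules/utils.py.)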
@@ -251,7 +203,7 @@ def get_model_specific_settings(model):
     return model_settings
 
 
-def load_model_specific_settings(model, state, return_dict=False):
+def load_model_specific_settings(model, state):
     model_settings = get_model_specific_settings(model)
     for k in model_settings:
         if k in state:
@@ -448,7 +400,7 @@ def create_chat_settings_menus():
 
 
 def create_settings_menus(default_preset):
-    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
+    generate_params = presets.load_preset(default_preset)
     with gr.Row():
         with gr.Column():
             with gr.Row():
@@ -515,7 +467,7 @@ def create_settings_menus(default_preset):
             shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
             shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
 
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
+    shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
 
 
 def create_file_saving_menus():
@@ -578,7 +530,7 @@ def create_file_saving_event_handlers():
 
     shared.gradio['save_preset'].click(
         ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-        generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
+        presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
        lambda: 'presets/', None, shared.gradio['save_root']).then(
         lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then(
         lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
@@ -1043,7 +995,7 @@ def create_interface():
             shared.gradio['save_prompt'].click(
                 lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then(
                 lambda: 'prompts/', None, shared.gradio['save_root']).then(
-                lambda: current_time() + '.txt', None, shared.gradio['save_filename']).then(
+                lambda: utils.current_time() + '.txt', None, shared.gradio['save_filename']).then(
                 lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
 
             shared.gradio['delete_prompt'].click(