mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
commit
884871c107
@ -127,10 +127,23 @@ class ModelDownloader:
|
||||
if classifications[i] in ['pytorch', 'pt']:
|
||||
links.pop(i)
|
||||
|
||||
# For GGUF, try to download only the Q4_K_M if no specific file is specified.
|
||||
# If not present, exclude all GGUFs, as that's likely a repository with both
|
||||
# GGUF and fp16 files.
|
||||
if has_gguf and specific_file is None:
|
||||
has_q4km = False
|
||||
for i in range(len(classifications) - 1, -1, -1):
|
||||
if 'q4_k_m' not in links[i].lower():
|
||||
links.pop(i)
|
||||
if 'q4_k_m' in links[i].lower():
|
||||
has_q4km = True
|
||||
|
||||
if has_q4km:
|
||||
for i in range(len(classifications) - 1, -1, -1):
|
||||
if 'q4_k_m' not in links[i].lower():
|
||||
links.pop(i)
|
||||
else:
|
||||
for i in range(len(classifications) - 1, -1, -1):
|
||||
if links[i].lower().endswith('.gguf'):
|
||||
links.pop(i)
|
||||
|
||||
is_llamacpp = has_gguf and specific_file is not None
|
||||
return links, sha256, is_lora, is_llamacpp
|
||||
|
@ -236,7 +236,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -
|
||||
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
if max_tokens in [None, 0]:
|
||||
generate_params['max_new_tokens'] = 200
|
||||
generate_params['max_new_tokens'] = 512
|
||||
generate_params['auto_max_new_tokens'] = True
|
||||
|
||||
requested_model = generate_params.pop('model')
|
||||
|
@ -10,7 +10,7 @@ class GenerationOptions(BaseModel):
|
||||
min_p: float = 0
|
||||
top_k: int = 0
|
||||
repetition_penalty: float = 1
|
||||
repetition_penalty_range: int = 0
|
||||
repetition_penalty_range: int = 1024
|
||||
typical_p: float = 1
|
||||
tfs: float = 1
|
||||
top_a: float = 0
|
||||
|
@ -165,10 +165,19 @@ class ExllamaModel:
|
||||
if has_leading_space:
|
||||
decoded_text = ' ' + decoded_text
|
||||
|
||||
yield decoded_text
|
||||
# Check the partial unicode character
|
||||
if chr(0xfffd) in decoded_text:
|
||||
is_last = i == max_new_tokens - 1
|
||||
is_stopping = token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything
|
||||
# If we are not at the end of the generation, we skip this token
|
||||
if not (is_last or is_stopping):
|
||||
continue
|
||||
|
||||
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
|
||||
break
|
||||
|
||||
yield decoded_text
|
||||
|
||||
# Case 2: CFG
|
||||
# Copied from https://github.com/turboderp/exllama/blob/master/example_cfg.py
|
||||
else:
|
||||
@ -205,6 +214,14 @@ class ExllamaModel:
|
||||
if has_leading_space:
|
||||
decoded_text = ' ' + decoded_text
|
||||
|
||||
# Check the partial unicode character
|
||||
if chr(0xfffd) in decoded_text:
|
||||
is_last = i == max_new_tokens - 1
|
||||
is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
|
||||
# If we are not at the end of the generation, we skip this token
|
||||
if not (is_last or is_stopping):
|
||||
continue
|
||||
|
||||
yield decoded_text
|
||||
if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
|
||||
break
|
||||
|
@ -138,11 +138,19 @@ class Exllamav2Model:
|
||||
if has_leading_space:
|
||||
decoded_text = ' ' + decoded_text
|
||||
|
||||
yield decoded_text
|
||||
# Check the partial unicode character
|
||||
if chr(0xfffd) in decoded_text:
|
||||
is_last = i == max_new_tokens - 1
|
||||
is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
|
||||
# If we are not at the end of the generation, we skip this token
|
||||
if not (is_last or is_stopping):
|
||||
continue
|
||||
|
||||
if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
|
||||
break
|
||||
|
||||
yield decoded_text
|
||||
|
||||
def generate(self, prompt, state):
|
||||
output = ''
|
||||
for output in self.generate_with_streaming(prompt, state):
|
||||
|
@ -143,6 +143,11 @@ loaders_and_params = OrderedDict({
|
||||
'no_mmap',
|
||||
'mlock'
|
||||
],
|
||||
'QuIP#': [
|
||||
'trust_remote_code',
|
||||
'no_use_fast',
|
||||
'no_flash_attn',
|
||||
]
|
||||
})
|
||||
|
||||
loaders_samplers = {
|
||||
@ -453,6 +458,43 @@ loaders_samplers = {
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'QuIP#': {
|
||||
'temperature',
|
||||
'temperature_last',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
'typical_p',
|
||||
'epsilon_cutoff',
|
||||
'eta_cutoff',
|
||||
'tfs',
|
||||
'top_a',
|
||||
'repetition_penalty',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'repetition_penalty_range',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'min_length',
|
||||
'seed',
|
||||
'do_sample',
|
||||
'penalty_alpha',
|
||||
'num_beams',
|
||||
'length_penalty',
|
||||
'early_stopping',
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
'custom_token_bans',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
}
|
||||
|
||||
loaders_model_types = {
|
||||
|
@ -1,4 +1,5 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@ -23,6 +24,7 @@ import modules.shared as shared
|
||||
from modules import RoPE, llama_attn_hijack, sampler_hijack
|
||||
from modules.logging_colors import logger
|
||||
from modules.models_settings import get_model_metadata
|
||||
from modules.relative_imports import RelativeImport
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -69,6 +71,7 @@ def load_model(model_name, loader=None):
|
||||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
||||
'ctransformers': ctransformers_loader,
|
||||
'AutoAWQ': AutoAWQ_loader,
|
||||
'QuIP#': QuipSharp_loader,
|
||||
}
|
||||
|
||||
metadata = get_model_metadata(model_name)
|
||||
@ -321,6 +324,37 @@ def AutoAWQ_loader(model_name):
|
||||
return model
|
||||
|
||||
|
||||
def QuipSharp_loader(model_name):
|
||||
try:
|
||||
with RelativeImport("repositories/quip-sharp"):
|
||||
from lib.utils.unsafe_import import model_from_hf_path
|
||||
except:
|
||||
logger.error(
|
||||
"\nQuIP# has not been found. It must be installed manually for now.\n"
|
||||
"For instructions on how to do that, please consult:\n"
|
||||
"https://github.com/oobabooga/text-generation-webui/pull/4803\n"
|
||||
)
|
||||
return None, None
|
||||
|
||||
# This fixes duplicate logging messages after the import above.
|
||||
handlers = logging.getLogger().handlers
|
||||
if len(handlers) > 1:
|
||||
logging.getLogger().removeHandler(handlers[1])
|
||||
|
||||
model_dir = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
|
||||
logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
|
||||
return None, None
|
||||
|
||||
model, model_str = model_from_hf_path(
|
||||
model_dir,
|
||||
use_cuda_graph=False,
|
||||
use_flash_attn=not shared.args.no_flash_attn
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def GPTQ_loader(model_name):
|
||||
|
||||
# Monkey patch
|
||||
|
@ -33,14 +33,24 @@ def get_model_metadata(model):
|
||||
for k in settings[pat]:
|
||||
model_settings[k] = settings[pat][k]
|
||||
|
||||
|
||||
path = Path(f'{shared.args.model_dir}/{model}/config.json')
|
||||
if path.exists():
|
||||
hf_metadata = json.loads(open(path, 'r').read())
|
||||
else:
|
||||
hf_metadata = None
|
||||
|
||||
if 'loader' not in model_settings:
|
||||
loader = infer_loader(model, model_settings)
|
||||
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
|
||||
loader = 'AutoGPTQ'
|
||||
if hf_metadata is not None and 'quip_params' in hf_metadata:
|
||||
model_settings['loader'] = 'QuIP#'
|
||||
else:
|
||||
loader = infer_loader(model, model_settings)
|
||||
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
|
||||
loader = 'AutoGPTQ'
|
||||
|
||||
model_settings['loader'] = loader
|
||||
model_settings['loader'] = loader
|
||||
|
||||
# Read GGUF metadata
|
||||
# GGUF metadata
|
||||
if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||
path = Path(f'{shared.args.model_dir}/{model}')
|
||||
if path.is_file():
|
||||
@ -57,9 +67,8 @@ def get_model_metadata(model):
|
||||
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
|
||||
|
||||
else:
|
||||
# Read transformers metadata
|
||||
path = Path(f'{shared.args.model_dir}/{model}/config.json')
|
||||
if path.exists():
|
||||
# Transformers metadata
|
||||
if hf_metadata is not None:
|
||||
metadata = json.loads(open(path, 'r').read())
|
||||
if 'max_position_embeddings' in metadata:
|
||||
model_settings['truncation_length'] = metadata['max_position_embeddings']
|
||||
|
@ -18,7 +18,7 @@ def default_preset():
|
||||
'repetition_penalty': 1,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
'repetition_penalty_range': 0,
|
||||
'repetition_penalty_range': 1024,
|
||||
'typical_p': 1,
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
|
@ -36,7 +36,7 @@ settings = {
|
||||
'prompt-default': 'QA',
|
||||
'prompt-notebook': 'QA',
|
||||
'preset': 'simple-1',
|
||||
'max_new_tokens': 200,
|
||||
'max_new_tokens': 512,
|
||||
'max_new_tokens_min': 1,
|
||||
'max_new_tokens_max': 4096,
|
||||
'negative_prompt': '',
|
||||
@ -241,6 +241,8 @@ def fix_loader_name(name):
|
||||
return 'ctransformers'
|
||||
elif name in ['autoawq', 'awq', 'auto-awq']:
|
||||
return 'AutoAWQ'
|
||||
elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
|
||||
return 'QuIP#'
|
||||
|
||||
|
||||
def add_extension(name, last=False):
|
||||
|
@ -264,14 +264,10 @@ def apply_stopping_strings(reply, all_stop_strings):
|
||||
|
||||
|
||||
def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
||||
if shared.is_seq2seq:
|
||||
reply = decode(output_ids, state['skip_special_tokens'])
|
||||
else:
|
||||
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
||||
# Prevent LlamaTokenizer from skipping a space
|
||||
if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
|
||||
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
|
||||
reply = ' ' + reply
|
||||
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
||||
if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > starting_from:
|
||||
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
|
||||
reply = ' ' + reply
|
||||
|
||||
return reply
|
||||
|
||||
@ -343,7 +339,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
if cuda:
|
||||
output = output.cuda()
|
||||
|
||||
yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0]))
|
||||
starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
|
||||
yield get_reply_from_output_ids(output, state, starting_from=starting_from)
|
||||
|
||||
# Stream the reply 1 token at a time.
|
||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||
@ -360,12 +357,17 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
cumulative_reply = ''
|
||||
starting_from = len(input_ids[0])
|
||||
starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
|
||||
for output in generator:
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
||||
cumulative_reply += get_reply_from_output_ids(output, state, starting_from=starting_from)
|
||||
new_content = get_reply_from_output_ids(output, state, starting_from=starting_from)
|
||||
# check the partial unicode character
|
||||
if chr(0xfffd) in new_content:
|
||||
continue
|
||||
|
||||
cumulative_reply += new_content
|
||||
starting_from = len(output)
|
||||
yield cumulative_reply
|
||||
|
||||
|
10
one_click.py
10
one_click.py
@ -4,6 +4,7 @@ import hashlib
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import signal
|
||||
import site
|
||||
import subprocess
|
||||
import sys
|
||||
@ -27,6 +28,13 @@ else:
|
||||
flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update'])} {CMD_FLAGS}"
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
def is_linux():
|
||||
return sys.platform.startswith("linux")
|
||||
|
||||
@ -210,7 +218,7 @@ def install_webui():
|
||||
elif is_linux() and (choice == "C" or choice == "N"):
|
||||
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
|
||||
elif choice == "D":
|
||||
install_pytorch = "python -m pip install torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu"
|
||||
install_pytorch = "python -m pip install torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
|
||||
|
||||
# Install Git and then Pytorch
|
||||
run_cmd(f"{install_git} && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True)
|
||||
|
12
server.py
12
server.py
@ -21,6 +21,7 @@ matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
@ -55,6 +56,17 @@ from modules.models_settings import (
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info("Received Ctrl+C. Shutting down Text generation web UI gracefully.")
|
||||
if 'interface' in shared.gradio:
|
||||
shared.gradio['interface'].close()
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
def create_interface():
|
||||
|
||||
title = 'Text generation web UI'
|
||||
|
@ -6,7 +6,7 @@ chat_style: cai-chat
|
||||
prompt-default: QA
|
||||
prompt-notebook: QA
|
||||
preset: simple-1
|
||||
max_new_tokens: 200
|
||||
max_new_tokens: 512
|
||||
max_new_tokens_min: 1
|
||||
max_new_tokens_max: 4096
|
||||
seed: -1
|
||||
|
Loading…
Reference in New Issue
Block a user