Merge branch 'dev' into style_improvements

Lounger 2023-12-20 23:09:15 +01:00
commit a098c7eee3
27 changed files with 217 additions and 224 deletions

View File

@@ -252,6 +252,7 @@ List of command-line flags

 | Flag | Description |
 |-------------|-------------|
+| `--tensorcores` | Use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards. NVIDIA only. |
 | `--n_ctx N_CTX` | Size of the prompt context. |
 | `--threads` | Number of threads to use. |
 | `--threads-batch THREADS_BATCH` | Number of threads to use for batches/prompt processing. |

View File

@@ -6,27 +6,13 @@ import time
from pathlib import Path

import gradio as gr
+from TTS.api import TTS
+from TTS.utils.synthesizer import Synthesizer

from modules import chat, shared, ui_chat
-from modules.logging_colors import logger
from modules.ui import create_refresh_button
from modules.utils import gradio

-try:
-    from TTS.api import TTS
-    from TTS.utils.synthesizer import Synthesizer
-except ModuleNotFoundError:
-    logger.error(
-        "Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
-        "\n"
-        "\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
-        "\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
-        "\n"
-        "If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
-    )
-    raise
-
os.environ["COQUI_TOS_AGREED"] = "1"

params = {

View File

@@ -1,5 +1,6 @@
import asyncio
import json
+import logging
import os
import traceback
from threading import Thread
@@ -367,6 +368,7 @@ def run_server():
    if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
        logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')

+    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
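The new `propagate = False` line keeps uvicorn's error records from also being re-emitted by the root logger once the web UI has installed its own handlers. A minimal, hedged illustration of the effect (the handler setup below is invented for the example, not the project's actual configuration):

```python
import logging

# Stand-in for the application's own root/console handler.
logging.basicConfig(level=logging.INFO, format='root | %(message)s')

uvicorn_error = logging.getLogger("uvicorn.error")
uvicorn_error.addHandler(logging.StreamHandler())  # stand-in for uvicorn's handler

# Without the next line the record is printed twice: once by uvicorn's handler
# and once more by the root handler it propagates to.
uvicorn_error.propagate = False

uvicorn_error.error("shown once, by uvicorn's handler only")
```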

View File

@@ -126,7 +126,7 @@ def load_quantized(model_name):
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    pt_path = find_quantized_model_file(model_name)
    if not pt_path:
-        logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
+        logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
        exit()
    else:
        logger.info(f"Found the following quantized model: {pt_path}")

View File

@@ -53,7 +53,10 @@ def add_lora_exllama(lora_names):
        lora_path = get_lora_path(lora_names[0])
        lora_config_path = lora_path / "adapter_config.json"
-        lora_adapter_path = lora_path / "adapter_model.bin"
+        for file_name in ["adapter_model.safetensors", "adapter_model.bin"]:
+            file_path = lora_path / file_name
+            if file_path.is_file():
+                lora_adapter_path = file_path

        logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
        if shared.model.__class__.__name__ == 'ExllamaModel':
@@ -138,7 +141,7 @@ def add_lora_transformers(lora_names):
    # Add a LoRA when another LoRA is already present
    if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
-        logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
+        logger.info(f"Adding the LoRA(s) named {added_set} to the model")
        for lora in added_set:
            shared.model.load_adapter(get_lora_path(lora), lora)
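For readers skimming the first hunk above: the loop now looks for the newer `adapter_model.safetensors` name as well as the legacy `adapter_model.bin`, and whichever candidate exists last in the list ends up in `lora_adapter_path`. A standalone, hedged restatement of that lookup (the helper name and example path are illustrative, not part of the project):

```python
from pathlib import Path

def find_adapter_file(lora_path):
    """Return the adapter weights file for a LoRA directory, or None if absent.

    Mirrors the loop in the diff above: every existing candidate overwrites the
    previous match, so the last existing name in the list wins.
    """
    adapter_path = None
    for file_name in ["adapter_model.safetensors", "adapter_model.bin"]:
        file_path = lora_path / file_name
        if file_path.is_file():
            adapter_path = file_path

    return adapter_path

# Hypothetical usage; the directory layout is an assumption for illustration.
print(find_adapter_file(Path("loras/my-lora")))
```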

View File

@@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
    else:
        renderer = chat_renderer
        if state['context'].strip() != '':
-            messages.append({"role": "system", "content": state['context']})
+            context = replace_character_names(state['context'], state['name1'], state['name2'])
+            messages.append({"role": "system", "content": context})

    insert_pos = len(messages)
    for user_msg, assistant_msg in reversed(history):
@@ -768,13 +769,13 @@ def delete_character(name, instruct=False):

def jinja_template_from_old_format(params, verbose=False):
    MASTER_TEMPLATE = """
-{%- set found_item = false -%}
+{%- set ns = namespace(found=false) -%}
{%- for message in messages -%}
    {%- if message['role'] == 'system' -%}
-        {%- set found_item = true -%}
+        {%- set ns.found = true -%}
    {%- endif -%}
{%- endfor -%}
-{%- if not found_item -%}
+{%- if not ns.found -%}
    {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
{%- endif %}
{%- for message in messages %}
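The `namespace()` change above is the standard fix for Jinja2 scoping: a plain `{% set %}` inside a `{% for %}` block only affects that iteration, so the outer `found_item` flag never became true. A small, hedged demonstration with made-up messages:

```python
from jinja2 import Template

messages = [{"role": "user", "content": "hi"}]

# With a plain variable, the assignment inside the loop is lost outside it.
broken = Template(
    "{%- set found = false -%}"
    "{%- for m in messages -%}{%- set found = true -%}{%- endfor -%}"
    "{{ found }}"
)
print(broken.render(messages=messages))  # False

# namespace() gives a mutable container whose attribute survives the loop.
fixed = Template(
    "{%- set ns = namespace(found=false) -%}"
    "{%- for m in messages -%}{%- set ns.found = true -%}{%- endfor -%}"
    "{{ ns.found }}"
)
print(fixed.render(messages=messages))  # True
```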

View File

@@ -1,4 +1,3 @@
-import random
import traceback
from pathlib import Path
@@ -10,7 +9,7 @@ from exllamav2 import (
    ExLlamaV2Config,
    ExLlamaV2Tokenizer
)
-from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
+from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

from modules import shared
from modules.logging_colors import logger
@@ -64,7 +63,7 @@ class Exllamav2Model:
        else:
            cache = ExLlamaV2Cache(model)

-        generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
+        generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

        result = self()
        result.model = model
@@ -115,41 +114,21 @@ class Exllamav2Model:
        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
        ids = ids[:, -get_max_prompt_length(state):]
-        initial_len = ids.shape[-1]

        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

-        # _gen_begin_base
-        self.cache.current_seq_len = 0
-        self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
+        self.generator.begin_stream(ids, settings, loras=self.loras)

-        has_leading_space = False
+        decoded_text = ''
        for i in range(max_new_tokens):
-            logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
-            token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
-            ids = torch.cat([ids, token], dim=1)
-
-            if i == 0 and self.tokenizer.tokenizer.id_to_piece(int(token)).startswith('▁'):
-                has_leading_space = True
-
-            decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0]
-            if has_leading_space:
-                decoded_text = ' ' + decoded_text
-
-            # Check the partial unicode character
-            if chr(0xfffd) in decoded_text:
-                is_last = i == max_new_tokens - 1
-                is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
-                # If we are not at the end of the generation, we skip this token
-                if not (is_last or is_stopping):
-                    continue
-
-            if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
+            chunk, eos, _ = self.generator.stream()
+            if eos or shared.stop_everything:
                break

+            decoded_text += chunk
            yield decoded_text

    def generate(self, prompt, state):
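The loader now relies on `ExLlamaV2StreamingGenerator`, which hides the manual forward/sample/decode bookkeeping: `begin_stream()` ingests the prompt and `stream()` returns already-decoded text chunks plus an end-of-stream flag. A hedged sketch of that calling pattern as a reusable helper, with model setup left to the caller (the helper name and token cap are illustrative):

```python
def stream_completion(generator, tokenizer, settings, prompt, max_new_tokens=200):
    """Yield the growing completion text, mirroring the loop in the diff above.

    `generator` is an ExLlamaV2StreamingGenerator built from a loaded model;
    constructing one requires model weights, so that setup is not shown here.
    """
    ids = tokenizer.encode(prompt, add_bos=True)
    generator.begin_stream(ids, settings)

    text = ''
    for _ in range(max_new_tokens):
        chunk, eos, _ = generator.stream()  # chunk is already-decoded text
        if eos:
            break

        text += chunk
        yield text
```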

View File

@@ -31,9 +31,14 @@ def load_extensions():
    for i, name in enumerate(shared.args.extensions):
        if name in available_extensions:
            if name != 'api':
-                logger.info(f'Loading the extension "{name}"...')
+                logger.info(f'Loading the extension "{name}"')
            try:
-                exec(f"import extensions.{name}.script")
+                try:
+                    exec(f"import extensions.{name}.script")
+                except ModuleNotFoundError:
+                    logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
+                    raise
                extension = getattr(extensions, name).script
                apply_settings(extension, name)
                if extension not in setup_called and hasattr(extension, "setup"):
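The inner `try/except` turns a missing dependency inside any extension's `script.py` into a single actionable message instead of a raw traceback, which is why the per-extension handler in coqui_tts could be dropped earlier in this commit. A standalone sketch of the same pattern using `importlib` (the helper and module layout are illustrative, not the project's code):

```python
import importlib
import logging

logger = logging.getLogger(__name__)

def import_extension(name):
    """Import extensions.<name>.script, pointing at its requirements file on failure."""
    try:
        return importlib.import_module(f"extensions.{name}.script")
    except ModuleNotFoundError:
        logger.error(
            "Could not import the requirements for '%s'. "
            "Try: pip install -r extensions/%s/requirements.txt --upgrade",
            name, name,
        )
        raise
```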

View File

@@ -20,12 +20,21 @@ try:
except:
    llama_cpp_cuda = None

+try:
+    import llama_cpp_cuda_tensorcores
+except:
+    llama_cpp_cuda_tensorcores = None
+

def llama_cpp_lib():
-    if (shared.args.cpu and llama_cpp is not None) or llama_cpp_cuda is None:
+    if shared.args.cpu and llama_cpp is not None:
        return llama_cpp
-    else:
+    elif shared.args.tensorcores and llama_cpp_cuda_tensorcores is not None:
+        return llama_cpp_cuda_tensorcores
+    elif llama_cpp_cuda is not None:
        return llama_cpp_cuda
+    else:
+        return llama_cpp


class LlamacppHF(PreTrainedModel):
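The new priority order is: the CPU build when `--cpu` is set, the tensor-cores CUDA build when `--tensorcores` is set and its wheel is installed, then the regular CUDA build, then the CPU build as a last resort. A minimal, hedged restatement of that decision as a pure function (the flag and module names mirror the diff; everything else is illustrative):

```python
def pick_llama_cpp_backend(cpu_flag, tensorcores_flag,
                           llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores):
    # 1. An explicit CPU request wins when the CPU wheel is available.
    if cpu_flag and llama_cpp is not None:
        return llama_cpp
    # 2. Tensor-cores CUDA wheel, only when requested and installed.
    if tensorcores_flag and llama_cpp_cuda_tensorcores is not None:
        return llama_cpp_cuda_tensorcores
    # 3. Regular CUDA wheel.
    if llama_cpp_cuda is not None:
        return llama_cpp_cuda
    # 4. Fall back to whatever CPU build is importable.
    return llama_cpp

# Example: CUDA wheel missing, tensor-cores wheel present but not requested.
assert pick_llama_cpp_backend(False, False, "cpu-lib", None, "tc-lib") == "cpu-lib"
```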

View File

@@ -19,12 +19,21 @@ try:
except:
    llama_cpp_cuda = None

+try:
+    import llama_cpp_cuda_tensorcores
+except:
+    llama_cpp_cuda_tensorcores = None
+

def llama_cpp_lib():
-    if (shared.args.cpu and llama_cpp is not None) or llama_cpp_cuda is None:
+    if shared.args.cpu and llama_cpp is not None:
        return llama_cpp
-    else:
+    elif shared.args.tensorcores and llama_cpp_cuda_tensorcores is not None:
+        return llama_cpp_cuda_tensorcores
+    elif llama_cpp_cuda is not None:
        return llama_cpp_cuda
+    else:
+        return llama_cpp


def ban_eos_logits_processor(eos_token, input_ids, logits):

View File

@@ -43,7 +43,8 @@ loaders_and_params = OrderedDict({
        'compress_pos_emb',
        'cpu',
        'numa',
-        'no_offload_kqv'
+        'no_offload_kqv',
+        'tensorcores',
    ],
    'llamacpp_HF': [
        'n_ctx',
@@ -65,6 +66,7 @@ loaders_and_params = OrderedDict({
        'no_use_fast',
        'logits_all',
        'no_offload_kqv',
+        'tensorcores',
        'llamacpp_HF_info',
    ],
    'ExLlamav2_HF': [
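The hunk above ends mid-dictionary; the substance of the change is simply that `loaders_and_params` maps each loader name to the UI elements relevant to it, and 'tensorcores' is now listed for both llama.cpp loaders so the new checkbox appears when they are selected. A toy, hedged sketch of how such a mapping can drive element visibility (the helper and the trimmed-down mapping are illustrative, not the project's code):

```python
from collections import OrderedDict

# Toy version of the mapping: loader name -> UI elements relevant to it.
loaders_and_params = OrderedDict({
    'llama.cpp': ['n_ctx', 'threads', 'no_offload_kqv', 'tensorcores'],
    'Transformers': ['load_in_8bit', 'bf16'],
})

def visible_elements(loader, all_elements):
    """Return a visibility flag per UI element for the selected loader."""
    relevant = set(loaders_and_params.get(loader, []))
    return {name: name in relevant for name in all_elements}

print(visible_elements('llama.cpp', ['n_ctx', 'tensorcores', 'bf16']))
# {'n_ctx': True, 'tensorcores': True, 'bf16': False}
```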

View File

@ -1,117 +1,67 @@
# Copied from https://stackoverflow.com/a/1336640
import logging import logging
import platform
logging.basicConfig(
format='%(asctime)s %(levelname)s:%(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
def add_coloring_to_emit_windows(fn):
# add methods we need to the class
def _out_handle(self):
import ctypes
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
out_handle = property(_out_handle)
def _set_color(self, code):
import ctypes
# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
setattr(logging.StreamHandler, '_set_color', _set_color)
def new(*args):
FOREGROUND_BLUE = 0x0001 # text color contains blue.
FOREGROUND_GREEN = 0x0002 # text color contains green.
FOREGROUND_RED = 0x0004 # text color contains red.
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
# winbase.h
# STD_INPUT_HANDLE = -10
# STD_OUTPUT_HANDLE = -11
# STD_ERROR_HANDLE = -12
# wincon.h
# FOREGROUND_BLACK = 0x0000
FOREGROUND_BLUE = 0x0001
FOREGROUND_GREEN = 0x0002
# FOREGROUND_CYAN = 0x0003
FOREGROUND_RED = 0x0004
FOREGROUND_MAGENTA = 0x0005
FOREGROUND_YELLOW = 0x0006
# FOREGROUND_GREY = 0x0007
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
# BACKGROUND_BLACK = 0x0000
# BACKGROUND_BLUE = 0x0010
# BACKGROUND_GREEN = 0x0020
# BACKGROUND_CYAN = 0x0030
# BACKGROUND_RED = 0x0040
# BACKGROUND_MAGENTA = 0x0050
BACKGROUND_YELLOW = 0x0060
# BACKGROUND_GREY = 0x0070
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
levelno = args[1].levelno
if (levelno >= 50):
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
elif (levelno >= 40):
color = FOREGROUND_RED | FOREGROUND_INTENSITY
elif (levelno >= 30):
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
elif (levelno >= 20):
color = FOREGROUND_GREEN
elif (levelno >= 10):
color = FOREGROUND_MAGENTA
else:
color = FOREGROUND_WHITE
args[0]._set_color(color)
ret = fn(*args)
args[0]._set_color(FOREGROUND_WHITE)
# print "after"
return ret
return new
def add_coloring_to_emit_ansi(fn):
# add methods we need to the class
def new(*args):
levelno = args[1].levelno
if (levelno >= 50):
color = '\x1b[31m' # red
elif (levelno >= 40):
color = '\x1b[31m' # red
elif (levelno >= 30):
color = '\x1b[33m' # yellow
elif (levelno >= 20):
color = '\x1b[32m' # green
elif (levelno >= 10):
color = '\x1b[35m' # pink
else:
color = '\x1b[0m' # normal
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
# print "after"
return fn(*args)
return new
if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
else:
# all non-Windows platforms are supporting ANSI escapes so we use them
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
# log = logging.getLogger()
# log.addFilter(log_filter())
# //hdlr = logging.StreamHandler()
# //hdlr.setFormatter(formatter())
logger = logging.getLogger('text-generation-webui') logger = logging.getLogger('text-generation-webui')
logger.setLevel(logging.DEBUG)
def setup_logging():
'''
Copied from: https://github.com/vladmandic/automatic
All credits to vladmandic.
'''
class RingBuffer(logging.StreamHandler):
def __init__(self, capacity):
super().__init__()
self.capacity = capacity
self.buffer = []
self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }')
def emit(self, record):
msg = self.format(record)
# self.buffer.append(json.loads(msg))
self.buffer.append(msg)
if len(self.buffer) > self.capacity:
self.buffer.pop(0)
def get(self):
return self.buffer
from rich.console import Console
from rich.logging import RichHandler
from rich.pretty import install as pretty_install
from rich.theme import Theme
from rich.traceback import install as traceback_install
level = logging.DEBUG
logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}))
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
pretty_install(console=console)
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
while logger.hasHandlers() and len(logger.handlers) > 0:
logger.removeHandler(logger.handlers[0])
# handlers
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
rh.setLevel(level)
logger.addHandler(rh)
rb = RingBuffer(100) # 100 entries default in log ring buffer
rb.setLevel(level)
logger.addHandler(rb)
logger.buffer = rb.buffer
# overrides
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("lycoris").handlers = logger.handlers
setup_logging()
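Elsewhere in the codebase this module is consumed as `from modules.logging_colors import logger`; since `setup_logging()` runs at import time, callers just log. A small, hedged usage sketch (the messages are invented):

```python
from modules.logging_colors import logger

logger.info('Loading the extension "gallery"')
logger.warning('The --chat flag has been deprecated and will be removed soon.')

# The RingBuffer handler keeps the last 100 formatted records in memory,
# exposed on the logger for other components to read.
print(len(logger.buffer))
```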

View File

@@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()

def load_model(model_name, loader=None):
-    logger.info(f"Loading {model_name}...")
+    logger.info(f"Loading {model_name}")
    t0 = time.time()

    shared.is_seq2seq = False
@@ -413,8 +413,8 @@ def ExLlamav2_HF_loader(model_name):

def HQQ_loader(model_name):
+    from hqq.core.quantize import HQQBackend, HQQLinear
    from hqq.engine.hf import HQQModelForCausalLM
-    from hqq.core.quantize import HQQLinear, HQQBackend

    logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")

View File

@@ -106,6 +106,7 @@ parser.add_argument('--compute_dtype', type=str, default='float16', help='comput
parser.add_argument('--quant_type', type=str, default='nf4', help='quant_type for 4-bit. Valid options: nf4, fp4.')

# llama.cpp
+parser.add_argument('--tensorcores', action='store_true', help='Use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards. NVIDIA only.')
parser.add_argument('--n_ctx', type=int, default=2048, help='Size of the prompt context.')
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
parser.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
@@ -203,22 +204,26 @@ for arg in sys.argv[1:]:
    if hasattr(args, arg):
        provided_arguments.append(arg)

-# Deprecation warnings
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
-for k in deprecated_args:
-    if getattr(args, k):
-        logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')

-# Security warnings
-if args.trust_remote_code:
-    logger.warning('trust_remote_code is enabled. This is dangerous.')
-if 'COLAB_GPU' not in os.environ and not args.nowebui:
-    if args.share:
-        logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
-    if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
-        logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
-    if args.multi_user:
-        logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')

+def do_cmd_flags_warnings():
+
+    # Deprecation warnings
+    for k in deprecated_args:
+        if getattr(args, k):
+            logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
+
+    # Security warnings
+    if args.trust_remote_code:
+        logger.warning('trust_remote_code is enabled. This is dangerous.')
+    if 'COLAB_GPU' not in os.environ and not args.nowebui:
+        if args.share:
+            logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
+        if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
+            logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
+        if args.multi_user:
+            logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
+

def fix_loader_name(name):

View File

@@ -249,7 +249,7 @@ def backup_adapter(input_folder):
    adapter_file = Path(f"{input_folder}/adapter_model.bin")
    if adapter_file.is_file():
-        logger.info("Backing up existing LoRA adapter...")
+        logger.info("Backing up existing LoRA adapter")
        creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
        creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
@@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
    # == Prep the dataset, format, etc ==
    if raw_text_file not in ['None', '']:
        train_template["template_type"] = "raw_text"
-        logger.info("Loading raw text file dataset...")
+        logger.info("Loading raw text file dataset")
        fullpath = clean_path('training/datasets', f'{raw_text_file}')
        fullpath = Path(fullpath)
        if fullpath.is_dir():
@@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
            prompt = generate_prompt(data_point)
            return tokenize(prompt, add_eos_token)

-        logger.info("Loading JSON datasets...")
+        logger.info("Loading JSON datasets")
        data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
        train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
@@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
    # == Start prepping the model itself ==
    if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
-        logger.info("Getting model ready...")
+        logger.info("Getting model ready")
        prepare_model_for_kbit_training(shared.model)

        # base model is now frozen and should not be reused for any other LoRA training than this one
        shared.model_dirty_from_training = True

-    logger.info("Preparing for training...")
+    logger.info("Preparing for training")
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
@@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
    model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)

    try:
-        logger.info("Creating LoRA model...")
+        logger.info("Creating LoRA model")
        lora_model = get_peft_model(shared.model, config)
        if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
-            logger.info("Loading existing LoRA data...")
+            logger.info("Loading existing LoRA data")
            state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
            set_peft_model_state_dict(lora_model, state_dict_peft)
    except:
@@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
        json.dump(train_template, file, indent=2)

    # == Main run and monitor loop ==
-    logger.info("Starting training...")
+    logger.info("Starting training")
    yield "Starting..."

    lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
@@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
    # Saving in the train thread might fail if an error occurs, so save here if so.
    if not tracked.did_save:
-        logger.info("Training complete, saving...")
+        logger.info("Training complete, saving")
        lora_model.save_pretrained(lora_file_path)

    if WANT_INTERRUPT:

View File

@@ -92,6 +92,7 @@ def list_model_elements():
        'numa',
        'logits_all',
        'no_offload_kqv',
+        'tensorcores',
        'hqq_backend',
    ]

    if is_torch_xpu_available():

View File

@@ -105,6 +105,8 @@ def create_ui():
                        shared.gradio['quipsharp_info'] = gr.Markdown('QuIP# only works on Linux.')

                    with gr.Column():
+                        shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='Use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards. NVIDIA only.')
+                        shared.gradio['no_offload_kqv'] = gr.Checkbox(label="no_offload_kqv", value=shared.args.no_offload_kqv, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
                        shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
                        shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention", value=shared.args.no_inject_fused_attention, info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Fuses layers for AutoAWQ. Disable if running low on VRAM.')
                        shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.')
@@ -115,7 +117,6 @@ def create_ui():
                        shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
                        shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
                        shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
-                        shared.gradio['no_offload_kqv'] = gr.Checkbox(label="no_offload_kqv", value=shared.args.no_offload_kqv, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
                        shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
                        shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
                        shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece
@@ -35,6 +36,26 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.24+cpuavx2-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.24+cpuavx2-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"

+# llama-cpp-python (CUDA, no tensor cores)
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp311-cp311-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp310-cp310-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp39-cp39-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp38-cp38-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
+
+# llama-cpp-python (CUDA, tensor cores)
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp311-cp311-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp310-cp310-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp39-cp39-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121-cp38-cp38-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
+
# CUDA wheels
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
@@ -68,14 +89,6 @@ https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu122torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp311-cp311-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp310-cp310-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp39-cp39-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121-cp38-cp38-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece
@@ -35,6 +36,26 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.24+cpuavx-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.2.24+cpuavx-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"

+# llama-cpp-python (CUDA, no tensor cores)
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp311-cp311-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp310-cp310-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp39-cp39-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp38-cp38-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
+
+# llama-cpp-python (CUDA, tensor cores)
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp311-cp311-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp310-cp310-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp39-cp39-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.2.24+cu121avx-cp38-cp38-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
+
# CUDA wheels
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/jllllll/AutoGPTQ/releases/download/v0.6.0/auto_gptq-0.6.0+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
@@ -68,14 +89,6 @@ https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu122torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp311-cp311-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp310-cp310-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp39-cp39-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.2.24+cu121avx-cp38-cp38-manylinux_2_31_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.1/gptq_for_llama-0.1.1+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9"

View File

@@ -4,7 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
-hqq==0.1.1
+hqq==0.1.1.post1
markdown
numpy==1.24.*
optimum==1.16.*
@@ -13,6 +13,7 @@ peft==0.7.*
Pillow>=9.5.0
pyyaml
requests
+rich
safetensors==0.4.1
scipy
sentencepiece

View File

@@ -12,6 +12,8 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
+warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
+warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')

with RequestBlocker():
    import gradio as gr
@@ -54,6 +56,7 @@ from modules.models_settings import (
    get_model_metadata,
    update_model_parameters
)
+from modules.shared import do_cmd_flags_warnings
from modules.utils import gradio
@@ -170,6 +173,9 @@ def create_interface():

if __name__ == "__main__":

+    logger.info("Starting Text generation web UI")
+    do_cmd_flags_warnings()
+
    # Load custom settings
    settings_file = None
    if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -180,7 +186,7 @@ if __name__ == "__main__":
        settings_file = Path('settings.json')

    if settings_file is not None:
-        logger.info(f"Loading settings from {settings_file}...")
+        logger.info(f"Loading settings from {settings_file}")
        file_contents = open(settings_file, 'r', encoding='utf-8').read()
        new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
        shared.settings.update(new_settings)
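A small note on the two new warning filters above: `warnings.filterwarnings` treats the `message` argument as a regular expression matched against the start of the warning text, so a prefix such as 'The value passed into gr.Dropdown()' silences every warning that begins with that string. A hedged standalone illustration (the warning texts are invented):

```python
import warnings

warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')

# Suppressed: the message starts with the filtered prefix.
warnings.warn('The value passed into gr.Dropdown() is not in the list of choices', UserWarning)

# Still shown: different prefix.
warnings.warn('Some other gradio warning', UserWarning)
```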