Better warning messages

oobabooga 2023-05-03 21:43:17 -03:00
parent 0a48b29cd8
commit 95d04d6a8d
13 changed files with 194 additions and 83 deletions
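Every file in this diff applies the same pattern: messages that used to go through bare print() calls now go through the standard logging module, which is configured once in server.py. A minimal sketch of that pattern for context (the message text is taken from the diff below; the snippet itself is not part of the commit):

import logging

# configured once at startup, as server.py does further down in this diff
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)

# before: print('Warning: more than one .pt model has been found. ...')
# after:
logging.warning('More than one .pt model has been found.')
# output: WARNING:More than one .pt model has been found.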

modules/GPTQ_loader.py

@@ -1,4 +1,5 @@
 import inspect
+import logging
 import re
 import sys
 from pathlib import Path
@@ -71,7 +72,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     del layers
-    print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
         model.load_state_dict(safe_load(checkpoint), strict=False)
@@ -90,8 +90,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
         quant.autotune_warmup_fused(model)
     model.seqlen = 2048
-    print('Done.')
     return model
@@ -119,11 +117,13 @@ def find_quantized_model_file(model_name):
        if len(found_pts) > 0:
            if len(found_pts) > 1:
-                print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
+                logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
            pt_path = found_pts[-1]
        elif len(found_safetensors) > 0:
            if len(found_pts) > 1:
-                print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
+                logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
            pt_path = found_safetensors[-1]
    return pt_path
@@ -142,8 +142,7 @@ def load_quantized(model_name):
        elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
            model_type = 'gptj'
        else:
-            print("Can't determine model type from model name. Please specify it manually using --model_type "
-                  "argument")
+            logging.error("Can't determine model type from model name. Please specify it manually using --model_type argument")
            exit()
    else:
        model_type = shared.args.model_type.lower()
@@ -153,20 +152,21 @@ def load_quantized(model_name):
        load_quant = llama_inference_offload.load_quant
    elif model_type in ('llama', 'opt', 'gptj'):
        if shared.args.pre_layer:
-            print("Warning: ignoring --pre_layer because it only works for llama model type.")
+            logging.warning("Ignoring --pre_layer because it only works for llama model type.")
        load_quant = _load_quant
    else:
-        print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
+        logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
        exit()
    # Find the quantized model weights file (.pt/.safetensors)
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    pt_path = find_quantized_model_file(model_name)
    if not pt_path:
-        print("Could not find the quantized model in .pt or .safetensors format, exiting...")
+        logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
        exit()
    else:
-        print(f"Found the following quantized model: {pt_path}")
+        logging.info(f"Found the following quantized model: {pt_path}")
    # qwopqwop200's offload
    if model_type == 'llama' and shared.args.pre_layer:
@@ -188,7 +188,7 @@ def load_quantized(model_name):
            max_memory = accelerate.utils.get_balanced_memory(model)
        device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
-        print("Using the following device map for the quantized model:", device_map)
+        logging.info("Using the following device map for the quantized model:", device_map)
        # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
        model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)

modules/LoRA.py

@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 import torch
@@ -18,7 +19,7 @@ def add_lora_to_model(lora_names):
    # Add a LoRA when another LoRA is already present
    if len(removed_set) == 0 and len(prior_set) > 0:
-        print(f"Adding the LoRA(s) named {added_set} to the model...")
+        logging.info(f"Adding the LoRA(s) named {added_set} to the model...")
        for lora in added_set:
            shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
@@ -29,7 +30,7 @@ def add_lora_to_model(lora_names):
            shared.model.disable_adapter()
    if len(lora_names) > 0:
-        print("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
+        logging.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
        params = {}
        if not shared.args.cpu:
            params['dtype'] = shared.model.dtype

modules/chat.py

@@ -3,6 +3,7 @@ import base64
 import copy
 import io
 import json
+import logging
 import re
 from datetime import datetime
 from pathlib import Path
@@ -138,7 +139,7 @@ def extract_message_from_reply(reply, state):
 def chatbot_wrapper(text, state, regenerate=False, _continue=False):
    if shared.model_name == 'None' or shared.model is None:
-        print("No model is loaded! Select one in the Model tab.")
+        logging.error("No model is loaded! Select one in the Model tab.")
        yield shared.history['visible']
        return
@@ -216,7 +217,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
 def impersonate_wrapper(text, state):
    if shared.model_name == 'None' or shared.model is None:
-        print("No model is loaded! Select one in the Model tab.")
+        logging.error("No model is loaded! Select one in the Model tab.")
        yield ''
        return
@@ -523,7 +524,7 @@ def upload_character(json_file, img, tavern=False):
        img = Image.open(io.BytesIO(img))
        img.save(Path(f'characters/{outfile_name}.png'))
-    print(f'New character saved to "characters/{outfile_name}.json".')
+    logging.info(f'New character saved to "characters/{outfile_name}.json".')
    return outfile_name
@@ -547,6 +548,6 @@ def upload_your_profile_picture(img, name1, name2, mode):
    else:
        img = make_thumbnail(img)
    img.save(Path('cache/pfp_me.png'))
-    print('Profile picture saved to "cache/pfp_me.png"')
+    logging.info('Profile picture saved to "cache/pfp_me.png"')
    return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)

modules/extensions.py

@@ -1,3 +1,4 @@
+import logging
 import traceback
 from functools import partial
@@ -28,7 +29,7 @@ def load_extensions():
    for i, name in enumerate(shared.args.extensions):
        if name in available_extensions:
            if name != 'api':
-                print(f'Loading the extension "{name}"... ', end='')
+                logging.info(f'Loading the extension "{name}"...')
            try:
                exec(f"import extensions.{name}.script")
                extension = getattr(extensions, name).script
@@ -38,12 +39,8 @@ def load_extensions():
                    extension.setup()
                state[name] = [True, i]
-                if name != 'api':
-                    print('Ok.')
            except:
-                if name != 'api':
-                    print('Fail.')
+                logging.error('Failed to load the extension "{name}".')
                traceback.print_exc()

modules/llama_attn_hijack.py

@@ -1,29 +1,28 @@
+import logging
 import math
 import sys
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import transformers.models.llama.modeling_llama
-from typing import Optional
-from typing import Tuple
 import modules.shared as shared
 if shared.args.xformers:
    try:
        import xformers.ops
    except Exception:
-        print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
+        logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
 def hijack_llama_attention():
    if shared.args.xformers:
        transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
-        print("Replaced attention with xformers_attention")
+        logging.info("Replaced attention with xformers_attention")
    elif shared.args.sdp_attention:
        transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
-        print("Replaced attention with sdp_attention")
+        logging.info("Replaced attention with sdp_attention")
 def xformers_forward(
@@ -57,8 +56,6 @@ def xformers_forward(
    # We only apply xformers optimizations if we don't need to output the whole attention matrix
    if not output_attentions:
-        dtype = query_states.dtype
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
@@ -102,9 +99,7 @@ def xformers_forward(
        attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights, past_key_value

modules/logging_colors.py (new file, 109 lines)

@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+# encoding: utf-8
+import logging
+# now we patch Python code to add color support to logging.StreamHandler
+
+
+def add_coloring_to_emit_windows(fn):
+    # add methods we need to the class
+    def _out_handle(self):
+        import ctypes
+        return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+    out_handle = property(_out_handle)
+
+    def _set_color(self, code):
+        import ctypes
+        # Constants from the Windows API
+        self.STD_OUTPUT_HANDLE = -11
+        hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+        ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
+
+    setattr(logging.StreamHandler, '_set_color', _set_color)
+
+    def new(*args):
+        FOREGROUND_BLUE = 0x0001  # text color contains blue.
+        FOREGROUND_GREEN = 0x0002  # text color contains green.
+        FOREGROUND_RED = 0x0004  # text color contains red.
+        FOREGROUND_INTENSITY = 0x0008  # text color is intensified.
+        FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
+        # winbase.h
+        # STD_INPUT_HANDLE = -10
+        # STD_OUTPUT_HANDLE = -11
+        # STD_ERROR_HANDLE = -12
+
+        # wincon.h
+        # FOREGROUND_BLACK = 0x0000
+        FOREGROUND_BLUE = 0x0001
+        FOREGROUND_GREEN = 0x0002
+        # FOREGROUND_CYAN = 0x0003
+        FOREGROUND_RED = 0x0004
+        FOREGROUND_MAGENTA = 0x0005
+        FOREGROUND_YELLOW = 0x0006
+        # FOREGROUND_GREY = 0x0007
+        FOREGROUND_INTENSITY = 0x0008  # foreground color is intensified.
+
+        # BACKGROUND_BLACK = 0x0000
+        # BACKGROUND_BLUE = 0x0010
+        # BACKGROUND_GREEN = 0x0020
+        # BACKGROUND_CYAN = 0x0030
+        # BACKGROUND_RED = 0x0040
+        # BACKGROUND_MAGENTA = 0x0050
+        BACKGROUND_YELLOW = 0x0060
+        # BACKGROUND_GREY = 0x0070
+        BACKGROUND_INTENSITY = 0x0080  # background color is intensified.
+
+        levelno = args[1].levelno
+        if (levelno >= 50):
+            color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
+        elif (levelno >= 40):
+            color = FOREGROUND_RED | FOREGROUND_INTENSITY
+        elif (levelno >= 30):
+            color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
+        elif (levelno >= 20):
+            color = FOREGROUND_GREEN
+        elif (levelno >= 10):
+            color = FOREGROUND_MAGENTA
+        else:
+            color = FOREGROUND_WHITE
+        args[0]._set_color(color)
+        ret = fn(*args)
+        args[0]._set_color(FOREGROUND_WHITE)
+        # print "after"
+        return ret
+    return new
+
+
+def add_coloring_to_emit_ansi(fn):
+    # add methods we need to the class
+    def new(*args):
+        levelno = args[1].levelno
+        if (levelno >= 50):
+            color = '\x1b[31m'  # red
+        elif (levelno >= 40):
+            color = '\x1b[31m'  # red
+        elif (levelno >= 30):
+            color = '\x1b[33m'  # yellow
+        elif (levelno >= 20):
+            color = '\x1b[32m'  # green
+        elif (levelno >= 10):
+            color = '\x1b[35m'  # pink
+        else:
+            color = '\x1b[0m'  # normal
+        args[1].msg = color + args[1].msg + '\x1b[0m'  # normal
+        # print "after"
+        return fn(*args)
+    return new
+
+
+import platform
+if platform.system() == 'Windows':
+    # Windows does not support ANSI escapes and we are using API calls to set the console color
+    logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
+else:
+    # all non-Windows platforms are supporting ANSI escapes so we use them
+    logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
+    # log = logging.getLogger()
+    # log.addFilter(log_filter())
+    # //hdlr = logging.StreamHandler()
+    # //hdlr.setFormatter(formatter())
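Importing this module is enough to activate the coloring, because it replaces logging.StreamHandler.emit at import time (ANSI escapes on non-Windows systems, console API calls on Windows). A minimal usage sketch, assuming it is run from the repository root so that the modules package resolves, as server.py does below; the messages are illustrative:

import logging

import modules.logging_colors  # patches logging.StreamHandler.emit on import

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
logging.info("Loading settings...")           # green on ANSI terminals
logging.warning("No GPU has been detected.")  # yellow
logging.error("No model is loaded!")          # red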

modules/models.py

@@ -1,5 +1,6 @@
 import gc
 import json
+import logging
 import os
 import re
 import time
@@ -65,7 +66,7 @@ def find_model_type(model_name):
 def load_model(model_name):
-    print(f"Loading {model_name}...")
+    logging.info(f"Loading {model_name}...")
    t0 = time.time()
    shared.model_type = find_model_type(model_name)
@@ -116,7 +117,7 @@ def load_model(model_name):
        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
-        print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+        logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
    # RMKV model (not on HuggingFace)
    elif shared.model_type == 'rwkv':
@@ -137,7 +138,7 @@ def load_model(model_name):
        else:
            model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
-        print(f"llama.cpp weights detected: {model_file}\n")
+        logging.info(f"llama.cpp weights detected: {model_file}\n")
        model, tokenizer = LlamaCppModel.from_pretrained(model_file)
        return model, tokenizer
@@ -146,7 +147,7 @@ def load_model(model_name):
        # Monkey patch
        if shared.args.monkey_patch:
-            print("Warning: applying the monkey patch for using LoRAs in 4-bit mode.\nIt may cause undefined behavior outside its intended scope.")
+            logging.warning("Warning: applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
            from modules.monkey_patch_gptq_lora import load_model_llama
            model, _ = load_model_llama(model_name)
@@ -161,7 +162,7 @@ def load_model(model_name):
    else:
        params = {"low_cpu_mem_usage": True}
        if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
-            print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
+            logging.warning("Warning: torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
            shared.args.cpu = True
        if shared.args.cpu:
@@ -184,6 +185,7 @@ def load_model(model_name):
                max_memory = {}
                for i in range(len(memory_map)):
                    max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
                max_memory['cpu'] = max_cpu_memory
                params['max_memory'] = max_memory
            elif shared.args.auto_devices:
@@ -191,9 +193,9 @@ def load_model(model_name):
                suggestion = round((total_mem - 1000) / 1000) * 1000
                if total_mem - suggestion < 800:
                    suggestion -= 1000
                suggestion = int(round(suggestion / 1000))
-                print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
+                logging.warning(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
                max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
                params['max_memory'] = max_memory
@@ -201,11 +203,11 @@ def load_model(model_name):
                params["offload_folder"] = shared.args.disk_cache_dir
            checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
            if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
                config = AutoConfig.from_pretrained(checkpoint)
                with init_empty_weights():
                    model = LoaderClass.from_config(config)
                model.tie_weights()
                params['device_map'] = infer_auto_device_map(
                    model,
@@ -230,7 +232,7 @@ def load_model(model_name):
    if shared.model_type != 'llava':
        for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
            if p.exists():
-                print(f"Loading the universal LLaMA tokenizer from {p}...")
+                logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
                tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
                break
@@ -247,7 +249,7 @@ def load_model(model_name):
        else:
            tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
-    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
+    logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
@@ -276,20 +278,20 @@ def load_soft_prompt(name):
        zf.extract('tensor.npy')
        zf.extract('meta.json')
        j = json.loads(open('meta.json', 'r').read())
-        print(f"\nLoading the softprompt \"{name}\".")
+        logging.info(f"\nLoading the softprompt \"{name}\".")
        for field in j:
            if field != 'name':
                if type(j[field]) is list:
-                    print(f"{field}: {', '.join(j[field])}")
+                    logging.info(f"{field}: {', '.join(j[field])}")
                else:
-                    print(f"{field}: {j[field]}")
+                    logging.info(f"{field}: {j[field]}")
-        print()
+        logging.info()
        tensor = np.load('tensor.npy')
        Path('tensor.npy').unlink()
        Path('meta.json').unlink()
        tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
        tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
        shared.soft_prompt = True
        shared.soft_prompt_tensor = tensor

modules/monkey_patch_gptq_lora.py

@@ -17,6 +17,7 @@ from modules.GPTQ_loader import find_quantized_model_file
 replace_peft_model_with_gptq_lora_model()
 def load_model_llama(model_name):
    config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
    model_path = str(find_quantized_model_file(model_name))

modules/shared.py

@@ -1,4 +1,5 @@
 import argparse
+import logging
 from pathlib import Path
 import yaml
@@ -170,19 +171,19 @@ args_defaults = parser.parse_args([])
 deprecated_dict = {}
 for k in deprecated_dict:
    if getattr(args, k) != deprecated_dict[k][1]:
-        print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.\n")
+        logging.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
        setattr(args, deprecated_dict[k][0], getattr(args, k))
 # Deprecation warnings for parameters that have been removed
 if args.cai_chat:
-    print("Warning: --cai-chat is deprecated. Use --chat instead.\n")
+    logging.warning("--cai-chat is deprecated. Use --chat instead.")
    args.chat = True
 # Security warnings
 if args.trust_remote_code:
-    print("Warning: trust_remote_code is enabled. This is dangerous.\n")
+    logging.warning("trust_remote_code is enabled. This is dangerous.")
 if args.share:
-    print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")
+    logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
 # Activating the API extension
 if args.api or args.public_api:
modules/text_generation.py

@@ -1,4 +1,5 @@
 import ast
+import logging
 import random
 import re
 import time
@@ -175,7 +176,7 @@ def get_generate_params(state):
 def generate_reply(question, state, eos_token=None, stopping_strings=[]):
    if shared.model_name == 'None' or shared.model is None:
-        print("No model is loaded! Select one in the Model tab.")
+        logging.error("No model is loaded! Select one in the Model tab.")
        yield formatted_outputs(question, shared.model_name)
        return

modules/training.py

@@ -1,4 +1,5 @@
 import json
+import logging
 import math
 import sys
 import threading
@@ -40,7 +41,6 @@ WANT_INTERRUPT = False
 PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"]
 def get_datasets(path: str, ext: str):
    return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
@@ -220,13 +220,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
        if model_type == "PeftModelForCausalLM":
            if len(shared.args.lora_names) > 0:
                yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-                print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
+                logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
            else:
                yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-                print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
+                logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
        else:
            yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-            print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
+            logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
        time.sleep(5)
    if shared.args.wbits > 0 and not shared.args.monkey_patch:
@@ -235,7 +236,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
    elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
        yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
-        print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
+        logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
        time.sleep(2)  # Give it a moment for the message to show in UI before continuing
    if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
@@ -255,7 +256,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
    # == Prep the dataset, format, etc ==
    if raw_text_file not in ['None', '']:
-        print("Loading raw text file dataset...")
+        logging.info("Loading raw text file dataset...")
        with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
            raw_text = file.read()
@@ -299,7 +300,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
            prompt = generate_prompt(data_point)
            return tokenize(prompt)
-        print("Loading JSON datasets...")
+        logging.info("Loading JSON datasets...")
        data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
        train_data = data['train'].map(generate_and_tokenize_prompt)
@@ -311,10 +312,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
    # == Start prepping the model itself ==
    if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
-        print("Getting model ready...")
+        logging.info("Getting model ready...")
        prepare_model_for_int8_training(shared.model)
-    print("Prepping for training...")
+    logging.info("Prepping for training...")
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
@@ -325,10 +326,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
    )
    try:
-        print("Creating LoRA model...")
+        logging.info("Creating LoRA model...")
        lora_model = get_peft_model(shared.model, config)
        if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
-            print("Loading existing LoRA data...")
+            logging.info("Loading existing LoRA data...")
            state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
            set_peft_model_state_dict(lora_model, state_dict_peft)
    except:
@@ -406,7 +407,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
        json.dump({x: vars[x] for x in PARAMETERS}, file)
    # == Main run and monitor loop ==
-    print("Starting training...")
+    logging.info("Starting training...")
    yield "Starting..."
    if WANT_INTERRUPT:
        yield "Interrupted before start."
@@ -416,7 +417,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
        trainer.train()
        # Note: save in the thread in case the gradio thread breaks (eg browser closed)
        lora_model.save_pretrained(lora_file_path)
-        print("LoRA training run is completed and saved.")
+        logging.info("LoRA training run is completed and saved.")
        tracked.did_save = True
    thread = threading.Thread(target=threaded_run)
@@ -448,14 +449,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
    # Saving in the train thread might fail if an error occurs, so save here if so.
    if not tracked.did_save:
-        print("Training complete, saving...")
+        logging.info("Training complete, saving...")
        lora_model.save_pretrained(lora_file_path)
    if WANT_INTERRUPT:
-        print("Training interrupted.")
+        logging.info("Training interrupted.")
        yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
    else:
-        print("Training complete!")
+        logging.info("Training complete!")
        yield f"Done! LoRA saved to `{lora_file_path}`"

modules/ui.py

@@ -25,6 +25,7 @@ theme = gr.themes.Default(
    background_fill_secondary='#eaeaea'
 )
 def list_model_elements():
    elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
    for i in range(torch.cuda.device_count()):

server.py

@@ -1,14 +1,17 @@
+import logging
 import os
 import requests
 import warnings
+import modules.logging_colors
 os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
 # This is a hack to prevent Gradio from phoning home when it gets imported
 def my_get(url, **kwargs):
-    print('Gradio HTTP request redirected to localhost :)')
+    logging.info('Gradio HTTP request redirected to localhost :)')
    kwargs.setdefault('allow_redirects', True)
    return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
@@ -17,9 +20,8 @@ requests.get = my_get
 import gradio as gr
 requests.get = original_get
-# This fixes LaTeX rendering on some systems
 import matplotlib
-matplotlib.use('Agg')
+matplotlib.use('Agg')  # This fixes LaTeX rendering on some systems
 import importlib
 import io
@@ -39,7 +41,6 @@ import psutil
 import torch
 import yaml
 from PIL import Image
 import modules.extensions as extensions_module
 from modules import chat, shared, training, ui
 from modules.html_generator import chat_html_wrapper
@@ -860,7 +861,7 @@ if __name__ == "__main__":
    elif Path('settings.json').exists():
        settings_file = Path('settings.json')
    if settings_file is not None:
-        print(f"Loading settings from {settings_file}...")
+        logging.info(f"Loading settings from {settings_file}...")
        new_settings = json.loads(open(settings_file, 'r').read())
        for item in new_settings:
            shared.settings[item] = new_settings[item]
@@ -891,7 +892,7 @@ if __name__ == "__main__":
    # Select the model from a command-line menu
    elif shared.args.model_menu:
        if len(available_models) == 0:
-            print('No models are available! Please download at least one.')
+            logging.error('No models are available! Please download at least one.')
            sys.exit(0)
        else:
            print('The following models are available:\n')