mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Better warning messages
This commit is contained in:
parent
0a48b29cd8
commit
95d04d6a8d
@ -1,4 +1,5 @@
|
|||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -71,7 +72,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||||||
|
|
||||||
del layers
|
del layers
|
||||||
|
|
||||||
print('Loading model ...')
|
|
||||||
if checkpoint.endswith('.safetensors'):
|
if checkpoint.endswith('.safetensors'):
|
||||||
from safetensors.torch import load_file as safe_load
|
from safetensors.torch import load_file as safe_load
|
||||||
model.load_state_dict(safe_load(checkpoint), strict=False)
|
model.load_state_dict(safe_load(checkpoint), strict=False)
|
||||||
@ -90,8 +90,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||||||
quant.autotune_warmup_fused(model)
|
quant.autotune_warmup_fused(model)
|
||||||
|
|
||||||
model.seqlen = 2048
|
model.seqlen = 2048
|
||||||
print('Done.')
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -119,11 +117,13 @@ def find_quantized_model_file(model_name):
|
|||||||
|
|
||||||
if len(found_pts) > 0:
|
if len(found_pts) > 0:
|
||||||
if len(found_pts) > 1:
|
if len(found_pts) > 1:
|
||||||
print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
|
logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
|
||||||
|
|
||||||
pt_path = found_pts[-1]
|
pt_path = found_pts[-1]
|
||||||
elif len(found_safetensors) > 0:
|
elif len(found_safetensors) > 0:
|
||||||
if len(found_pts) > 1:
|
if len(found_pts) > 1:
|
||||||
print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
|
logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
|
||||||
|
|
||||||
pt_path = found_safetensors[-1]
|
pt_path = found_safetensors[-1]
|
||||||
|
|
||||||
return pt_path
|
return pt_path
|
||||||
@ -142,8 +142,7 @@ def load_quantized(model_name):
|
|||||||
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
|
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
|
||||||
model_type = 'gptj'
|
model_type = 'gptj'
|
||||||
else:
|
else:
|
||||||
print("Can't determine model type from model name. Please specify it manually using --model_type "
|
logging.error("Can't determine model type from model name. Please specify it manually using --model_type argument")
|
||||||
"argument")
|
|
||||||
exit()
|
exit()
|
||||||
else:
|
else:
|
||||||
model_type = shared.args.model_type.lower()
|
model_type = shared.args.model_type.lower()
|
||||||
@ -153,20 +152,21 @@ def load_quantized(model_name):
|
|||||||
load_quant = llama_inference_offload.load_quant
|
load_quant = llama_inference_offload.load_quant
|
||||||
elif model_type in ('llama', 'opt', 'gptj'):
|
elif model_type in ('llama', 'opt', 'gptj'):
|
||||||
if shared.args.pre_layer:
|
if shared.args.pre_layer:
|
||||||
print("Warning: ignoring --pre_layer because it only works for llama model type.")
|
logging.warning("Ignoring --pre_layer because it only works for llama model type.")
|
||||||
|
|
||||||
load_quant = _load_quant
|
load_quant = _load_quant
|
||||||
else:
|
else:
|
||||||
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
# Find the quantized model weights file (.pt/.safetensors)
|
# Find the quantized model weights file (.pt/.safetensors)
|
||||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
pt_path = find_quantized_model_file(model_name)
|
pt_path = find_quantized_model_file(model_name)
|
||||||
if not pt_path:
|
if not pt_path:
|
||||||
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||||
exit()
|
exit()
|
||||||
else:
|
else:
|
||||||
print(f"Found the following quantized model: {pt_path}")
|
logging.info(f"Found the following quantized model: {pt_path}")
|
||||||
|
|
||||||
# qwopqwop200's offload
|
# qwopqwop200's offload
|
||||||
if model_type == 'llama' and shared.args.pre_layer:
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
@ -188,7 +188,7 @@ def load_quantized(model_name):
|
|||||||
max_memory = accelerate.utils.get_balanced_memory(model)
|
max_memory = accelerate.utils.get_balanced_memory(model)
|
||||||
|
|
||||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
||||||
print("Using the following device map for the quantized model:", device_map)
|
logging.info("Using the following device map for the quantized model:", device_map)
|
||||||
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
|
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
|
||||||
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
|
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -18,7 +19,7 @@ def add_lora_to_model(lora_names):
|
|||||||
|
|
||||||
# Add a LoRA when another LoRA is already present
|
# Add a LoRA when another LoRA is already present
|
||||||
if len(removed_set) == 0 and len(prior_set) > 0:
|
if len(removed_set) == 0 and len(prior_set) > 0:
|
||||||
print(f"Adding the LoRA(s) named {added_set} to the model...")
|
logging.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
||||||
for lora in added_set:
|
for lora in added_set:
|
||||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ def add_lora_to_model(lora_names):
|
|||||||
shared.model.disable_adapter()
|
shared.model.disable_adapter()
|
||||||
|
|
||||||
if len(lora_names) > 0:
|
if len(lora_names) > 0:
|
||||||
print("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
logging.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||||
params = {}
|
params = {}
|
||||||
if not shared.args.cpu:
|
if not shared.args.cpu:
|
||||||
params['dtype'] = shared.model.dtype
|
params['dtype'] = shared.model.dtype
|
||||||
|
@ -3,6 +3,7 @@ import base64
|
|||||||
import copy
|
import copy
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -138,7 +139,7 @@ def extract_message_from_reply(reply, state):
|
|||||||
|
|
||||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
print("No model is loaded! Select one in the Model tab.")
|
logging.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -216,7 +217,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||||||
|
|
||||||
def impersonate_wrapper(text, state):
|
def impersonate_wrapper(text, state):
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
print("No model is loaded! Select one in the Model tab.")
|
logging.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield ''
|
yield ''
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -523,7 +524,7 @@ def upload_character(json_file, img, tavern=False):
|
|||||||
img = Image.open(io.BytesIO(img))
|
img = Image.open(io.BytesIO(img))
|
||||||
img.save(Path(f'characters/{outfile_name}.png'))
|
img.save(Path(f'characters/{outfile_name}.png'))
|
||||||
|
|
||||||
print(f'New character saved to "characters/{outfile_name}.json".')
|
logging.info(f'New character saved to "characters/{outfile_name}.json".')
|
||||||
return outfile_name
|
return outfile_name
|
||||||
|
|
||||||
|
|
||||||
@ -547,6 +548,6 @@ def upload_your_profile_picture(img, name1, name2, mode):
|
|||||||
else:
|
else:
|
||||||
img = make_thumbnail(img)
|
img = make_thumbnail(img)
|
||||||
img.save(Path('cache/pfp_me.png'))
|
img.save(Path('cache/pfp_me.png'))
|
||||||
print('Profile picture saved to "cache/pfp_me.png"')
|
logging.info('Profile picture saved to "cache/pfp_me.png"')
|
||||||
|
|
||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -28,7 +29,7 @@ def load_extensions():
|
|||||||
for i, name in enumerate(shared.args.extensions):
|
for i, name in enumerate(shared.args.extensions):
|
||||||
if name in available_extensions:
|
if name in available_extensions:
|
||||||
if name != 'api':
|
if name != 'api':
|
||||||
print(f'Loading the extension "{name}"... ', end='')
|
logging.info(f'Loading the extension "{name}"...')
|
||||||
try:
|
try:
|
||||||
exec(f"import extensions.{name}.script")
|
exec(f"import extensions.{name}.script")
|
||||||
extension = getattr(extensions, name).script
|
extension = getattr(extensions, name).script
|
||||||
@ -38,12 +39,8 @@ def load_extensions():
|
|||||||
extension.setup()
|
extension.setup()
|
||||||
|
|
||||||
state[name] = [True, i]
|
state[name] = [True, i]
|
||||||
if name != 'api':
|
|
||||||
print('Ok.')
|
|
||||||
except:
|
except:
|
||||||
if name != 'api':
|
logging.error('Failed to load the extension "{name}".')
|
||||||
print('Fail.')
|
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,29 +1,28 @@
|
|||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
if shared.args.xformers:
|
if shared.args.xformers:
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
except Exception:
|
except Exception:
|
||||||
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_attention():
|
def hijack_llama_attention():
|
||||||
if shared.args.xformers:
|
if shared.args.xformers:
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
print("Replaced attention with xformers_attention")
|
logging.info("Replaced attention with xformers_attention")
|
||||||
elif shared.args.sdp_attention:
|
elif shared.args.sdp_attention:
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||||
print("Replaced attention with sdp_attention")
|
logging.info("Replaced attention with sdp_attention")
|
||||||
|
|
||||||
|
|
||||||
def xformers_forward(
|
def xformers_forward(
|
||||||
@ -55,16 +54,14 @@ def xformers_forward(
|
|||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
#We only apply xformers optimizations if we don't need to output the whole attention matrix
|
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
dtype = query_states.dtype
|
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
key_states = key_states.transpose(1, 2)
|
key_states = key_states.transpose(1, 2)
|
||||||
value_states = value_states.transpose(1, 2)
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||||
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||||
@ -102,9 +99,7 @@ def xformers_forward(
|
|||||||
attn_output = attn_output.transpose(1, 2)
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +132,7 @@ def sdp_attention_forward(
|
|||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
#We only apply sdp attention if we don't need to output the whole attention matrix
|
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
|
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
109
modules/logging_colors.py
Normal file
109
modules/logging_colors.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# encoding: utf-8
|
||||||
|
import logging
|
||||||
|
# now we patch Python code to add color support to logging.StreamHandler
|
||||||
|
|
||||||
|
|
||||||
|
def add_coloring_to_emit_windows(fn):
|
||||||
|
# add methods we need to the class
|
||||||
|
def _out_handle(self):
|
||||||
|
import ctypes
|
||||||
|
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
||||||
|
out_handle = property(_out_handle)
|
||||||
|
|
||||||
|
def _set_color(self, code):
|
||||||
|
import ctypes
|
||||||
|
# Constants from the Windows API
|
||||||
|
self.STD_OUTPUT_HANDLE = -11
|
||||||
|
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
||||||
|
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
|
||||||
|
|
||||||
|
setattr(logging.StreamHandler, '_set_color', _set_color)
|
||||||
|
|
||||||
|
def new(*args):
|
||||||
|
FOREGROUND_BLUE = 0x0001 # text color contains blue.
|
||||||
|
FOREGROUND_GREEN = 0x0002 # text color contains green.
|
||||||
|
FOREGROUND_RED = 0x0004 # text color contains red.
|
||||||
|
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
|
||||||
|
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
|
||||||
|
# winbase.h
|
||||||
|
# STD_INPUT_HANDLE = -10
|
||||||
|
# STD_OUTPUT_HANDLE = -11
|
||||||
|
# STD_ERROR_HANDLE = -12
|
||||||
|
|
||||||
|
# wincon.h
|
||||||
|
# FOREGROUND_BLACK = 0x0000
|
||||||
|
FOREGROUND_BLUE = 0x0001
|
||||||
|
FOREGROUND_GREEN = 0x0002
|
||||||
|
# FOREGROUND_CYAN = 0x0003
|
||||||
|
FOREGROUND_RED = 0x0004
|
||||||
|
FOREGROUND_MAGENTA = 0x0005
|
||||||
|
FOREGROUND_YELLOW = 0x0006
|
||||||
|
# FOREGROUND_GREY = 0x0007
|
||||||
|
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
|
||||||
|
|
||||||
|
# BACKGROUND_BLACK = 0x0000
|
||||||
|
# BACKGROUND_BLUE = 0x0010
|
||||||
|
# BACKGROUND_GREEN = 0x0020
|
||||||
|
# BACKGROUND_CYAN = 0x0030
|
||||||
|
# BACKGROUND_RED = 0x0040
|
||||||
|
# BACKGROUND_MAGENTA = 0x0050
|
||||||
|
BACKGROUND_YELLOW = 0x0060
|
||||||
|
# BACKGROUND_GREY = 0x0070
|
||||||
|
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
|
||||||
|
|
||||||
|
levelno = args[1].levelno
|
||||||
|
if (levelno >= 50):
|
||||||
|
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
|
||||||
|
elif (levelno >= 40):
|
||||||
|
color = FOREGROUND_RED | FOREGROUND_INTENSITY
|
||||||
|
elif (levelno >= 30):
|
||||||
|
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
|
||||||
|
elif (levelno >= 20):
|
||||||
|
color = FOREGROUND_GREEN
|
||||||
|
elif (levelno >= 10):
|
||||||
|
color = FOREGROUND_MAGENTA
|
||||||
|
else:
|
||||||
|
color = FOREGROUND_WHITE
|
||||||
|
args[0]._set_color(color)
|
||||||
|
|
||||||
|
ret = fn(*args)
|
||||||
|
args[0]._set_color(FOREGROUND_WHITE)
|
||||||
|
# print "after"
|
||||||
|
return ret
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
def add_coloring_to_emit_ansi(fn):
|
||||||
|
# add methods we need to the class
|
||||||
|
def new(*args):
|
||||||
|
levelno = args[1].levelno
|
||||||
|
if (levelno >= 50):
|
||||||
|
color = '\x1b[31m' # red
|
||||||
|
elif (levelno >= 40):
|
||||||
|
color = '\x1b[31m' # red
|
||||||
|
elif (levelno >= 30):
|
||||||
|
color = '\x1b[33m' # yellow
|
||||||
|
elif (levelno >= 20):
|
||||||
|
color = '\x1b[32m' # green
|
||||||
|
elif (levelno >= 10):
|
||||||
|
color = '\x1b[35m' # pink
|
||||||
|
else:
|
||||||
|
color = '\x1b[0m' # normal
|
||||||
|
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
|
||||||
|
# print "after"
|
||||||
|
return fn(*args)
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
import platform
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
# Windows does not support ANSI escapes and we are using API calls to set the console color
|
||||||
|
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
|
||||||
|
else:
|
||||||
|
# all non-Windows platforms are supporting ANSI escapes so we use them
|
||||||
|
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
|
||||||
|
# log = logging.getLogger()
|
||||||
|
# log.addFilter(log_filter())
|
||||||
|
# //hdlr = logging.StreamHandler()
|
||||||
|
# //hdlr.setFormatter(formatter())
|
@ -1,5 +1,6 @@
|
|||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -65,7 +66,7 @@ def find_model_type(model_name):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(model_name):
|
def load_model(model_name):
|
||||||
print(f"Loading {model_name}...")
|
logging.info(f"Loading {model_name}...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
shared.model_type = find_model_type(model_name)
|
shared.model_type = find_model_type(model_name)
|
||||||
@ -116,7 +117,7 @@ def load_model(model_name):
|
|||||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||||
model.module.eval() # Inference
|
model.module.eval() # Inference
|
||||||
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||||
|
|
||||||
# RMKV model (not on HuggingFace)
|
# RMKV model (not on HuggingFace)
|
||||||
elif shared.model_type == 'rwkv':
|
elif shared.model_type == 'rwkv':
|
||||||
@ -137,7 +138,7 @@ def load_model(model_name):
|
|||||||
else:
|
else:
|
||||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
||||||
|
|
||||||
print(f"llama.cpp weights detected: {model_file}\n")
|
logging.info(f"llama.cpp weights detected: {model_file}\n")
|
||||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@ -146,7 +147,7 @@ def load_model(model_name):
|
|||||||
|
|
||||||
# Monkey patch
|
# Monkey patch
|
||||||
if shared.args.monkey_patch:
|
if shared.args.monkey_patch:
|
||||||
print("Warning: applying the monkey patch for using LoRAs in 4-bit mode.\nIt may cause undefined behavior outside its intended scope.")
|
logging.warning("Warning: applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
|
||||||
from modules.monkey_patch_gptq_lora import load_model_llama
|
from modules.monkey_patch_gptq_lora import load_model_llama
|
||||||
|
|
||||||
model, _ = load_model_llama(model_name)
|
model, _ = load_model_llama(model_name)
|
||||||
@ -161,7 +162,7 @@ def load_model(model_name):
|
|||||||
else:
|
else:
|
||||||
params = {"low_cpu_mem_usage": True}
|
params = {"low_cpu_mem_usage": True}
|
||||||
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
||||||
print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
|
logging.warning("Warning: torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
|
||||||
shared.args.cpu = True
|
shared.args.cpu = True
|
||||||
|
|
||||||
if shared.args.cpu:
|
if shared.args.cpu:
|
||||||
@ -184,6 +185,7 @@ def load_model(model_name):
|
|||||||
max_memory = {}
|
max_memory = {}
|
||||||
for i in range(len(memory_map)):
|
for i in range(len(memory_map)):
|
||||||
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
|
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
|
||||||
|
|
||||||
max_memory['cpu'] = max_cpu_memory
|
max_memory['cpu'] = max_cpu_memory
|
||||||
params['max_memory'] = max_memory
|
params['max_memory'] = max_memory
|
||||||
elif shared.args.auto_devices:
|
elif shared.args.auto_devices:
|
||||||
@ -191,9 +193,9 @@ def load_model(model_name):
|
|||||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||||
if total_mem - suggestion < 800:
|
if total_mem - suggestion < 800:
|
||||||
suggestion -= 1000
|
suggestion -= 1000
|
||||||
suggestion = int(round(suggestion / 1000))
|
|
||||||
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
|
||||||
|
|
||||||
|
suggestion = int(round(suggestion / 1000))
|
||||||
|
logging.warning(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||||
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
||||||
params['max_memory'] = max_memory
|
params['max_memory'] = max_memory
|
||||||
|
|
||||||
@ -201,11 +203,11 @@ def load_model(model_name):
|
|||||||
params["offload_folder"] = shared.args.disk_cache_dir
|
params["offload_folder"] = shared.args.disk_cache_dir
|
||||||
|
|
||||||
checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
|
checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
|
|
||||||
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
||||||
config = AutoConfig.from_pretrained(checkpoint)
|
config = AutoConfig.from_pretrained(checkpoint)
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = LoaderClass.from_config(config)
|
model = LoaderClass.from_config(config)
|
||||||
|
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
params['device_map'] = infer_auto_device_map(
|
params['device_map'] = infer_auto_device_map(
|
||||||
model,
|
model,
|
||||||
@ -230,7 +232,7 @@ def load_model(model_name):
|
|||||||
if shared.model_type != 'llava':
|
if shared.model_type != 'llava':
|
||||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||||
if p.exists():
|
if p.exists():
|
||||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -247,7 +249,7 @@ def load_model(model_name):
|
|||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
|
||||||
|
|
||||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -276,20 +278,20 @@ def load_soft_prompt(name):
|
|||||||
zf.extract('tensor.npy')
|
zf.extract('tensor.npy')
|
||||||
zf.extract('meta.json')
|
zf.extract('meta.json')
|
||||||
j = json.loads(open('meta.json', 'r').read())
|
j = json.loads(open('meta.json', 'r').read())
|
||||||
print(f"\nLoading the softprompt \"{name}\".")
|
logging.info(f"\nLoading the softprompt \"{name}\".")
|
||||||
for field in j:
|
for field in j:
|
||||||
if field != 'name':
|
if field != 'name':
|
||||||
if type(j[field]) is list:
|
if type(j[field]) is list:
|
||||||
print(f"{field}: {', '.join(j[field])}")
|
logging.info(f"{field}: {', '.join(j[field])}")
|
||||||
else:
|
else:
|
||||||
print(f"{field}: {j[field]}")
|
logging.info(f"{field}: {j[field]}")
|
||||||
print()
|
logging.info()
|
||||||
tensor = np.load('tensor.npy')
|
tensor = np.load('tensor.npy')
|
||||||
Path('tensor.npy').unlink()
|
Path('tensor.npy').unlink()
|
||||||
Path('meta.json').unlink()
|
Path('meta.json').unlink()
|
||||||
|
|
||||||
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
|
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
|
||||||
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
|
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
|
||||||
|
|
||||||
shared.soft_prompt = True
|
shared.soft_prompt = True
|
||||||
shared.soft_prompt_tensor = tensor
|
shared.soft_prompt_tensor = tensor
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ from modules.GPTQ_loader import find_quantized_model_file
|
|||||||
|
|
||||||
replace_peft_model_with_gptq_lora_model()
|
replace_peft_model_with_gptq_lora_model()
|
||||||
|
|
||||||
|
|
||||||
def load_model_llama(model_name):
|
def load_model_llama(model_name):
|
||||||
config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
|
config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
|
||||||
model_path = str(find_quantized_model_file(model_name))
|
model_path = str(find_quantized_model_file(model_name))
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@ -170,19 +171,19 @@ args_defaults = parser.parse_args([])
|
|||||||
deprecated_dict = {}
|
deprecated_dict = {}
|
||||||
for k in deprecated_dict:
|
for k in deprecated_dict:
|
||||||
if getattr(args, k) != deprecated_dict[k][1]:
|
if getattr(args, k) != deprecated_dict[k][1]:
|
||||||
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.\n")
|
logging.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
||||||
setattr(args, deprecated_dict[k][0], getattr(args, k))
|
setattr(args, deprecated_dict[k][0], getattr(args, k))
|
||||||
|
|
||||||
# Deprecation warnings for parameters that have been removed
|
# Deprecation warnings for parameters that have been removed
|
||||||
if args.cai_chat:
|
if args.cai_chat:
|
||||||
print("Warning: --cai-chat is deprecated. Use --chat instead.\n")
|
logging.warning("--cai-chat is deprecated. Use --chat instead.")
|
||||||
args.chat = True
|
args.chat = True
|
||||||
|
|
||||||
# Security warnings
|
# Security warnings
|
||||||
if args.trust_remote_code:
|
if args.trust_remote_code:
|
||||||
print("Warning: trust_remote_code is enabled. This is dangerous.\n")
|
logging.warning("trust_remote_code is enabled. This is dangerous.")
|
||||||
if args.share:
|
if args.share:
|
||||||
print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")
|
logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
|
||||||
|
|
||||||
# Activating the API extension
|
# Activating the API extension
|
||||||
if args.api or args.public_api:
|
if args.api or args.public_api:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -175,7 +176,7 @@ def get_generate_params(state):
|
|||||||
|
|
||||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
print("No model is loaded! Select one in the Model tab.")
|
logging.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield formatted_outputs(question, shared.model_name)
|
yield formatted_outputs(question, shared.model_name)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
@ -40,7 +41,6 @@ WANT_INTERRUPT = False
|
|||||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"]
|
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_datasets(path: str, ext: str):
|
def get_datasets(path: str, ext: str):
|
||||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||||
|
|
||||||
@ -123,7 +123,7 @@ def create_train_interface():
|
|||||||
stop_evaluation = gr.Button("Interrupt")
|
stop_evaluation = gr.Button("Interrupt")
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
evaluation_log = gr.Markdown(value = '')
|
evaluation_log = gr.Markdown(value='')
|
||||||
|
|
||||||
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
|
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
|
||||||
save_comments = gr.Button('Save comments')
|
save_comments = gr.Button('Save comments')
|
||||||
@ -220,13 +220,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
if model_type == "PeftModelForCausalLM":
|
if model_type == "PeftModelForCausalLM":
|
||||||
if len(shared.args.lora_names) > 0:
|
if len(shared.args.lora_names) > 0:
|
||||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||||
else:
|
else:
|
||||||
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
|
logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
|
||||||
else:
|
else:
|
||||||
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
if shared.args.wbits > 0 and not shared.args.monkey_patch:
|
if shared.args.wbits > 0 and not shared.args.monkey_patch:
|
||||||
@ -235,7 +236,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
|
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
|
||||||
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||||
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||||
|
|
||||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||||
@ -255,7 +256,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
if raw_text_file not in ['None', '']:
|
||||||
print("Loading raw text file dataset...")
|
logging.info("Loading raw text file dataset...")
|
||||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||||
raw_text = file.read()
|
raw_text = file.read()
|
||||||
|
|
||||||
@ -299,7 +300,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
prompt = generate_prompt(data_point)
|
prompt = generate_prompt(data_point)
|
||||||
return tokenize(prompt)
|
return tokenize(prompt)
|
||||||
|
|
||||||
print("Loading JSON datasets...")
|
logging.info("Loading JSON datasets...")
|
||||||
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
||||||
train_data = data['train'].map(generate_and_tokenize_prompt)
|
train_data = data['train'].map(generate_and_tokenize_prompt)
|
||||||
|
|
||||||
@ -311,10 +312,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
# == Start prepping the model itself ==
|
# == Start prepping the model itself ==
|
||||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
print("Getting model ready...")
|
logging.info("Getting model ready...")
|
||||||
prepare_model_for_int8_training(shared.model)
|
prepare_model_for_int8_training(shared.model)
|
||||||
|
|
||||||
print("Prepping for training...")
|
logging.info("Prepping for training...")
|
||||||
config = LoraConfig(
|
config = LoraConfig(
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
@ -325,10 +326,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("Creating LoRA model...")
|
logging.info("Creating LoRA model...")
|
||||||
lora_model = get_peft_model(shared.model, config)
|
lora_model = get_peft_model(shared.model, config)
|
||||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
||||||
print("Loading existing LoRA data...")
|
logging.info("Loading existing LoRA data...")
|
||||||
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
|
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
|
||||||
set_peft_model_state_dict(lora_model, state_dict_peft)
|
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||||
except:
|
except:
|
||||||
@ -406,7 +407,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
json.dump({x: vars[x] for x in PARAMETERS}, file)
|
json.dump({x: vars[x] for x in PARAMETERS}, file)
|
||||||
|
|
||||||
# == Main run and monitor loop ==
|
# == Main run and monitor loop ==
|
||||||
print("Starting training...")
|
logging.info("Starting training...")
|
||||||
yield "Starting..."
|
yield "Starting..."
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
yield "Interrupted before start."
|
yield "Interrupted before start."
|
||||||
@ -416,7 +417,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
print("LoRA training run is completed and saved.")
|
logging.info("LoRA training run is completed and saved.")
|
||||||
tracked.did_save = True
|
tracked.did_save = True
|
||||||
|
|
||||||
thread = threading.Thread(target=threaded_run)
|
thread = threading.Thread(target=threaded_run)
|
||||||
@ -448,14 +449,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
# Saving in the train thread might fail if an error occurs, so save here if so.
|
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||||
if not tracked.did_save:
|
if not tracked.did_save:
|
||||||
print("Training complete, saving...")
|
logging.info("Training complete, saving...")
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
print("Training interrupted.")
|
logging.info("Training interrupted.")
|
||||||
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
|
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
|
||||||
else:
|
else:
|
||||||
print("Training complete!")
|
logging.info("Training complete!")
|
||||||
yield f"Done! LoRA saved to `{lora_file_path}`"
|
yield f"Done! LoRA saved to `{lora_file_path}`"
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ theme = gr.themes.Default(
|
|||||||
background_fill_secondary='#eaeaea'
|
background_fill_secondary='#eaeaea'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def list_model_elements():
|
def list_model_elements():
|
||||||
elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
|
elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
|
13
server.py
13
server.py
@ -1,14 +1,17 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
import warnings
|
import warnings
|
||||||
|
import modules.logging_colors
|
||||||
|
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||||
|
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
|
||||||
|
|
||||||
# This is a hack to prevent Gradio from phoning home when it gets imported
|
# This is a hack to prevent Gradio from phoning home when it gets imported
|
||||||
def my_get(url, **kwargs):
|
def my_get(url, **kwargs):
|
||||||
print('Gradio HTTP request redirected to localhost :)')
|
logging.info('Gradio HTTP request redirected to localhost :)')
|
||||||
kwargs.setdefault('allow_redirects', True)
|
kwargs.setdefault('allow_redirects', True)
|
||||||
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
|
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
|
||||||
|
|
||||||
@ -17,9 +20,8 @@ requests.get = my_get
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
requests.get = original_get
|
requests.get = original_get
|
||||||
|
|
||||||
# This fixes LaTeX rendering on some systems
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import io
|
import io
|
||||||
@ -39,7 +41,6 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
from modules import chat, shared, training, ui
|
from modules import chat, shared, training, ui
|
||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
@ -860,7 +861,7 @@ if __name__ == "__main__":
|
|||||||
elif Path('settings.json').exists():
|
elif Path('settings.json').exists():
|
||||||
settings_file = Path('settings.json')
|
settings_file = Path('settings.json')
|
||||||
if settings_file is not None:
|
if settings_file is not None:
|
||||||
print(f"Loading settings from {settings_file}...")
|
logging.info(f"Loading settings from {settings_file}...")
|
||||||
new_settings = json.loads(open(settings_file, 'r').read())
|
new_settings = json.loads(open(settings_file, 'r').read())
|
||||||
for item in new_settings:
|
for item in new_settings:
|
||||||
shared.settings[item] = new_settings[item]
|
shared.settings[item] = new_settings[item]
|
||||||
@ -891,7 +892,7 @@ if __name__ == "__main__":
|
|||||||
# Select the model from a command-line menu
|
# Select the model from a command-line menu
|
||||||
elif shared.args.model_menu:
|
elif shared.args.model_menu:
|
||||||
if len(available_models) == 0:
|
if len(available_models) == 0:
|
||||||
print('No models are available! Please download at least one.')
|
logging.error('No models are available! Please download at least one.')
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
print('The following models are available:\n')
|
print('The following models are available:\n')
|
||||||
|
Loading…
Reference in New Issue
Block a user