Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-10-01 01:26:03 -04:00.

Commit: Reorganize model loading UI completely (#2720)
commit 7ef6a50e84, parent 57be2eecdf
.gitignore (vendored, 1 line changed)

@@ -8,6 +8,7 @@ extensions/multimodal/pipelines
 logs
 loras
 models
+presets
 repositories
 softprompts
 torch-dumps
README.md

@@ -211,6 +211,12 @@ Optionally, you can use the following command-line flags:
 | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
 | `--verbose` | Print the prompts to the terminal. |
 
+#### Model loader
+
+| Flag | Description |
+|--------------------------------------------|-------------|
+| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen |
+
 #### Accelerate/transformers
 
 | Flag | Description |
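For example, with the new flag a specific backend can be selected explicitly at launch (the model name below is only a placeholder, not one of the repository's models):

```
python server.py --model some-model-GPTQ --loader autogptq
```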
@@ -265,7 +271,6 @@ Optionally, you can use the following command-line flags:
 
 | Flag | Description |
 |---------------------------|-------------|
-| `--gptq-for-llama` | Use GPTQ-for-LLaMa to load the GPTQ model instead of AutoGPTQ. |
 | `--wbits WBITS` | Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
 | `--model_type MODEL_TYPE` | Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
 | `--groupsize GROUPSIZE` | Group size. |
@@ -280,7 +285,6 @@ Optionally, you can use the following command-line flags:
 
 | Flag | Description |
 |------------------|-------------|
-| `--flexgen` | Enable the use of FlexGen offloading. |
 | `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
 | `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
 | `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
docs/FlexGen.md

@@ -21,13 +21,13 @@ The output will be saved to `models/opt-1.3b-np/`.
 The basic command is the following:
 
 ```
-python server.py --model opt-1.3b --flexgen
+python server.py --model opt-1.3b --loader flexgen
 ```
 
 For large models, the RAM usage may be too high and your computer may freeze. If that happens, you can try this:
 
 ```
-python server.py --model opt-1.3b --flexgen --compress-weight
+python server.py --model opt-1.3b --loader flexgen --compress-weight
 ```
 
 With this second command, I was able to run both OPT-6.7b and OPT-13B with **2GB VRAM**, and the speed was good in both cases.

@@ -35,7 +35,7 @@ With this second command, I was able to run both OPT-6.7b and OPT-13B with **2GB
 You can also manually set the offload strategy with
 
 ```
-python server.py --model opt-1.3b --flexgen --percent 0 100 100 0 100 0
+python server.py --model opt-1.3b --loader flexgen --percent 0 100 100 0 100 0
 ```
 
 where the six numbers after `--percent` are:

@@ -55,8 +55,8 @@ You should typically only change the first two numbers. If their sum is less tha
 
 In my experiments with OPT-30B using a RTX 3090 on Linux, I have obtained these results:
 
-* `--flexgen --compress-weight --percent 0 100 100 0 100 0`: 0.99 seconds per token.
-* `--flexgen --compress-weight --percent 100 0 100 0 100 0`: 0.765 seconds per token.
+* `--loader flexgen --compress-weight --percent 0 100 100 0 100 0`: 0.99 seconds per token.
+* `--loader flexgen --compress-weight --percent 100 0 100 0 100 0`: 0.765 seconds per token.
 
 ## Limitations
 
@@ -7,10 +7,11 @@ from modules import shared
 from modules.chat import generate_chat_reply
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, unload_model
+from modules.models_settings import (get_model_settings_from_yamls,
+                                     update_model_parameters)
 from modules.text_generation import (encode, generate_reply,
                                      stop_everything_event)
 from modules.utils import get_available_models
-from server import get_model_specific_settings, update_model_parameters
 
 
 def get_model_info():

@@ -22,6 +23,7 @@ def get_model_info():
         'shared.args': vars(shared.args),
     }
 
+
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
         if self.path == '/api/v1/model':

@@ -126,7 +128,7 @@ class Handler(BaseHTTPRequestHandler):
             shared.model_name = model_name
             unload_model()
 
-            model_settings = get_model_specific_settings(shared.model_name)
+            model_settings = get_model_settings_from_yamls(shared.model_name)
             shared.settings.update(model_settings)
             update_model_parameters(model_settings, initial=True)
 

@@ -136,10 +138,10 @@ class Handler(BaseHTTPRequestHandler):
            try:
                shared.model, shared.tokenizer = load_model(shared.model_name)
                if shared.args.lora:
                    add_lora_to_model(shared.args.lora) # list
 
            except Exception as e:
-               response = json.dumps({'error': { 'message': repr(e) } })
+               response = json.dumps({'error': {'message': repr(e)}})
 
                self.wfile.write(response.encode('utf-8'))
                raise e
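The model-switch sequence used by this API handler can be read as one small helper. The sketch below is not part of the diff; it simply restates the calls above in a single function and assumes it runs inside the web UI environment where these modules are importable:

```python
from modules import shared
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
                                     update_model_parameters)


def switch_model(model_name):
    # Unload whatever is currently loaded, then apply the per-model settings
    # from the YAML config files before loading the new model.
    shared.model_name = model_name
    unload_model()

    model_settings = get_model_settings_from_yamls(model_name)
    shared.settings.update(model_settings)                 # interface defaults
    update_model_parameters(model_settings, initial=True)  # command-line arguments

    shared.model, shared.tokenizer = load_model(shared.model_name)
    if shared.args.lora:
        add_lora_to_model(shared.args.lora)

    return shared.model is not None
```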
modules/LoRA.py

@@ -77,8 +77,7 @@ def add_lora_to_model(lora_names):
         elif shared.args.load_in_8bit:
             params['device_map'] = {'': 0}
 
-        shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"),adapter_name=lora_names[0], **params)
-
+        shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params)
         for lora in lora_names[1:]:
             shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
 
modules/RWKV.py

@@ -88,8 +88,8 @@ class RWKVModel:
            out, state = self.model.forward(tokens[:args.chunk_len], state)
            tokens = tokens[args.chunk_len:]
            if i == 0:
-               begin_token= len(all_tokens)
-               last_token_posi=begin_token
+               begin_token = len(all_tokens)
+               last_token_posi = begin_token
            # cache the model state after scanning the context
            # we don't cache the state after processing our own generated tokens because
            # the output string might be post-processed arbitrarily. Therefore, what's fed into the model

@@ -122,7 +122,7 @@ class RWKVModel:
            if '\ufffd' not in tmp: # is valid utf-8 string?
                if callback:
                    callback(tmp)
 
                out_str += tmp
                last_token_posi = begin_token + i + 1
        return out_str
modules/evaluate.py

@@ -8,8 +8,9 @@ from tqdm import tqdm
 
 from modules import shared
 from modules.models import load_model, unload_model
+from modules.models_settings import (get_model_settings_from_yamls,
+                                     update_model_parameters)
 from modules.text_generation import encode
-from server import get_model_specific_settings, update_model_parameters
 
 
 def load_past_evaluations():

@@ -66,7 +67,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
         if model != 'current model':
             try:
                 yield cumulative_log + f"Loading {model}...\n\n"
-                model_settings = get_model_specific_settings(model)
+                model_settings = get_model_settings_from_yamls(model)
                 shared.settings.update(model_settings) # hijacking the interface defaults
                 update_model_parameters(model_settings) # hijacking the command-line arguments
                 shared.model_name = model
modules/github.py

@@ -1,6 +1,7 @@
 import os
 import subprocess
 
+
 def clone_or_pull_repository(github_url):
     repository_folder = "extensions"
     repo_name = github_url.split("/")[-1].split(".")[0]
modules/loaders.py (new file, 86 lines)

@@ -0,0 +1,86 @@
+import functools
+
+import gradio as gr
+
+from modules import shared
+
+loaders_and_params = {
+    'AutoGPTQ': [
+        'triton',
+        'no_inject_fused_attention',
+        'no_inject_fused_mlp',
+        'wbits',
+        'groupsize',
+        'desc_act',
+        'gpu_memory',
+        'cpu_memory',
+        'cpu',
+        'disk',
+        'auto_devices',
+        'trust_remote_code',
+        'autogptq_info',
+    ],
+    'GPTQ-for-LLaMa': [
+        'wbits',
+        'groupsize',
+        'model_type',
+        'pre_layer',
+        'gptq_for_llama_info',
+    ],
+    'llama.cpp': [
+        'n_ctx',
+        'n_gpu_layers',
+        'n_batch',
+        'threads',
+        'no_mmap',
+        'mlock',
+        'llama_cpp_seed',
+    ],
+    'Transformers': [
+        'cpu_memory',
+        'gpu_memory',
+        'trust_remote_code',
+        'load_in_8bit',
+        'bf16',
+        'cpu',
+        'disk',
+        'auto_devices',
+        'load_in_4bit',
+        'use_double_quant',
+        'quant_type',
+        'compute_dtype',
+        'trust_remote_code',
+    ],
+}
+
+
+def get_gpu_memory_keys():
+    return [k for k in shared.gradio if k.startswith('gpu_memory')]
+
+
+@functools.cache
+def get_all_params():
+    all_params = set()
+    for k in loaders_and_params:
+        for el in loaders_and_params[k]:
+            all_params.add(el)
+
+    if 'gpu_memory' in all_params:
+        all_params.remove('gpu_memory')
+        for k in get_gpu_memory_keys():
+            all_params.add(k)
+
+    return sorted(all_params)
+
+
+def make_loader_params_visible(loader):
+    params = []
+    all_params = get_all_params()
+    if loader in loaders_and_params:
+        params = loaders_and_params[loader]
+
+        if 'gpu_memory' in params:
+            params.remove('gpu_memory')
+            params += get_gpu_memory_keys()
+
+    return [gr.update(visible=True) if k in params else gr.update(visible=False) for k in all_params]
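The mapping above drives which UI elements are shown for each loader: `get_all_params()` is the union of every loader's parameter names, and `make_loader_params_visible()` returns one visibility update per element. A standalone sketch of that logic, with made-up loader entries and without the Gradio dependency:

```python
# Illustrative only: a stripped-down version of the loader -> parameters mapping.
loaders_and_params = {
    'llama.cpp': ['n_ctx', 'n_gpu_layers', 'threads'],
    'Transformers': ['cpu_memory', 'gpu_memory', 'load_in_8bit'],
}


def get_all_params():
    # Union of every loader's parameters, sorted so the UI order is stable.
    return sorted({p for params in loaders_and_params.values() for p in params})


def visibility_flags(loader):
    # One boolean per element in get_all_params(): True if the selected
    # loader uses that element (the real code wraps these in gr.update()).
    params = loaders_and_params.get(loader, [])
    return [p in params for p in get_all_params()]


print(get_all_params())
print(visibility_flags('llama.cpp'))
```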
modules/models.py

@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
 import modules.shared as shared
 from modules import llama_attn_hijack, sampler_hijack
 from modules.logging_colors import logger
+from modules.models_settings import infer_loader
 
 transformers.logging.set_verbosity_error()
 
@@ -36,62 +37,31 @@ if shared.args.deepspeed:
 sampler_hijack.hijack_samplers()
 
 
-# Some models require special treatment in various parts of the code.
-# This function detects those models
-def find_model_type(model_name):
-    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-    if not path_to_model.exists():
-        return 'None'
-
-    model_name_lower = model_name.lower()
-    if re.match('.*rwkv.*\.pth', model_name_lower):
-        return 'rwkv'
-    elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
-        return 'llamacpp'
-    elif re.match('.*ggml.*\.bin', model_name_lower):
-        return 'llamacpp'
-    elif 'chatglm' in model_name_lower:
-        return 'chatglm'
-    elif 'galactica' in model_name_lower:
-        return 'galactica'
-    elif 'llava' in model_name_lower:
-        return 'llava'
-    elif 'oasst' in model_name_lower:
-        return 'oasst'
-    elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
-        return 'gpt4chan'
-    else:
-        config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
-        # Not a "catch all", but fairly accurate
-        if config.to_dict().get("is_encoder_decoder", False):
-            return 'HF_seq2seq'
-        else:
-            return 'HF_generic'
-
-
-def load_model(model_name):
+def load_model(model_name, loader=None):
     logger.info(f"Loading {model_name}...")
     t0 = time.time()
 
-    shared.model_type = find_model_type(model_name)
-    if shared.model_type == 'None':
-        logger.error('The path to the model does not exist. Exiting.')
-        return None, None
+    shared.is_seq2seq = False
+    load_func_map = {
+        'Transformers': huggingface_loader,
+        'AutoGPTQ': AutoGPTQ_loader,
+        'GPTQ-for-LLaMa': GPTQ_loader,
+        'llama.cpp': llamacpp_loader,
+        'FlexGen': flexgen_loader,
+        'RWKV': RWKV_loader
+    }
 
-    if shared.args.gptq_for_llama:
-        load_func = GPTQ_loader
-    elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or shared.args.wbits > 0:
-        load_func = AutoGPTQ_loader
-    elif shared.model_type == 'llamacpp':
-        load_func = llamacpp_loader
-    elif shared.model_type == 'rwkv':
-        load_func = RWKV_loader
-    elif shared.args.flexgen:
-        load_func = flexgen_loader
-    else:
-        load_func = huggingface_loader
-
-    output = load_func(model_name)
+    if loader is None:
+        if shared.args.loader is not None:
+            loader = shared.args.loader
+        else:
+            loader = infer_loader(model_name)
+            if loader is None:
+                logger.error('The path to the model does not exist. Exiting.')
+                return None, None
+
+    shared.args.loader = loader
+    output = load_func_map[loader](model_name)
     if type(output) is tuple:
         model, tokenizer = output
     else:
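The new `load_model` replaces the old chain of `elif` branches with a dictionary dispatch keyed by loader name. A self-contained sketch of that pattern (the loader names are real, the stub loader functions are not):

```python
# Illustrative stand-ins for the real loader functions in modules/models.py.
def huggingface_loader(model_name):
    return f"<Transformers model {model_name}>"


def llamacpp_loader(model_name):
    return f"<llama.cpp model {model_name}>"


load_func_map = {
    'Transformers': huggingface_loader,
    'llama.cpp': llamacpp_loader,
}


def load_model(model_name, loader=None):
    # If no loader is passed in, fall back to a default; the real code consults
    # shared.args.loader and infer_loader(model_name) here instead.
    loader = loader or 'Transformers'
    return load_func_map[loader](model_name)


print(load_model('opt-1.3b'))
print(load_model('some-ggml-model', loader='llama.cpp'))
```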
@@ -111,11 +81,11 @@ def load_model(model_name):
 
 def load_tokenizer(model_name, model):
     tokenizer = None
-    if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+    if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     elif type(model) is transformers.LlamaForCausalLM or "LlamaGPTQForCausalLM" in str(type(model)):
         # Try to load an universal LLaMA tokenizer
-        if shared.model_type not in ['llava', 'oasst']:
+        if any(s in shared.model_name.lower() for s in ['llava', 'oasst']):
             for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
                 if p.exists():
                     logger.info(f"Loading the universal LLaMA tokenizer from {p}...")

@@ -140,12 +110,16 @@ def load_tokenizer(model_name, model):
 
 
 def huggingface_loader(model_name):
-    if shared.model_type == 'chatglm':
+    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+    if 'chatglm' in model_name.lower():
         LoaderClass = AutoModel
-    elif shared.model_type == 'HF_seq2seq':
-        LoaderClass = AutoModelForSeq2SeqLM
     else:
-        LoaderClass = AutoModelForCausalLM
+        config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+        if config.to_dict().get("is_encoder_decoder", False):
+            LoaderClass = AutoModelForSeq2SeqLM
+            shared.is_seq2seq = True
+        else:
+            LoaderClass = AutoModelForCausalLM
 
     # Load the model in simple 16-bit mode by default
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
modules/models_settings.py (new file, 134 lines)

@@ -0,0 +1,134 @@
+import re
+from pathlib import Path
+
+import yaml
+
+from modules import shared, ui
+
+
+def get_model_settings_from_yamls(model):
+    settings = shared.model_config
+    model_settings = {}
+    for pat in settings:
+        if re.match(pat.lower(), model.lower()):
+            for k in settings[pat]:
+                model_settings[k] = settings[pat][k]
+
+    return model_settings
+
+
+def infer_loader(model_name):
+    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+    model_settings = get_model_settings_from_yamls(model_name)
+    if not path_to_model.exists():
+        loader = None
+    elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
+        loader = 'AutoGPTQ'
+    elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
+        loader = 'llama.cpp'
+    elif re.match('.*ggml.*\.bin', model_name.lower()):
+        loader = 'llama.cpp'
+    elif re.match('.*rwkv.*\.pth', model_name.lower()):
+        loader = 'RWKV'
+    elif shared.args.flexgen:
+        loader = 'FlexGen'
+    else:
+        loader = 'Transformers'
+
+    return loader
+
+
+# UI: update the command-line arguments based on the interface values
+def update_model_parameters(state, initial=False):
+    elements = ui.list_model_elements()  # the names of the parameters
+    gpu_memories = []
+
+    for i, element in enumerate(elements):
+        if element not in state:
+            continue
+
+        value = state[element]
+        if element.startswith('gpu_memory'):
+            gpu_memories.append(value)
+            continue
+
+        if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
+            continue
+
+        # Setting null defaults
+        if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
+            value = vars(shared.args_defaults)[element]
+        elif element in ['cpu_memory'] and value == 0:
+            value = vars(shared.args_defaults)[element]
+
+        # Making some simple conversions
+        if element in ['wbits', 'groupsize', 'pre_layer']:
+            value = int(value)
+        elif element == 'cpu_memory' and value is not None:
+            value = f"{value}MiB"
+
+        if element in ['pre_layer']:
+            value = [value] if value > 0 else None
+
+        setattr(shared.args, element, value)
+
+    found_positive = False
+    for i in gpu_memories:
+        if i > 0:
+            found_positive = True
+            break
+
+    if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
+        if found_positive:
+            shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
+        else:
+            shared.args.gpu_memory = None
+
+
+# UI: update the state variable with the model settings
+def apply_model_settings_to_state(model, state):
+    model_settings = get_model_settings_from_yamls(model)
+    if 'loader' not in model_settings:
+        loader = infer_loader(model)
+        if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
+            loader = 'AutoGPTQ'
+
+        # If the user is using an alternative GPTQ loader, let them keep using it
+        if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'exllama']):
+            state['loader'] = loader
+
+    for k in model_settings:
+        if k in state:
+            state[k] = model_settings[k]
+
+    return state
+
+
+# Save the settings for this model to models/config-user.yaml
+def save_model_settings(model, state):
+    if model == 'None':
+        yield ("Not saving the settings because no model is loaded.")
+        return
+
+    with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
+        if p.exists():
+            user_config = yaml.safe_load(open(p, 'r').read())
+        else:
+            user_config = {}
+
+        model_regex = model + '$' # For exact matches
+        for _dict in [user_config, shared.model_config]:
+            if model_regex not in _dict:
+                _dict[model_regex] = {}
+
+        if model_regex not in user_config:
+            user_config[model_regex] = {}
+
+        for k in ui.list_model_elements():
+            user_config[model_regex][k] = state[k]
+            shared.model_config[model_regex][k] = state[k]
+
+        with open(p, 'w') as f:
+            f.write(yaml.dump(user_config, sort_keys=False))
+
+        yield (f"Settings for {model} saved to {p}")
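The per-model settings in `config.yaml`/`config-user.yaml` are keyed by regular expressions that `get_model_settings_from_yamls` matches against the model name. A standalone sketch of that lookup, with made-up patterns and values:

```python
import re

# Illustrative only: shared.model_config is normally loaded from the YAML files.
model_config = {
    '.*-gptq': {'loader': 'AutoGPTQ', 'wbits': 4, 'groupsize': 128},
    'llama-.*': {'loader': 'llama.cpp', 'n_ctx': 2048},
}


def settings_for(model_name):
    # Every pattern that matches contributes its keys; later matches win.
    merged = {}
    for pattern, values in model_config.items():
        if re.match(pattern.lower(), model_name.lower()):
            merged.update(values)
    return merged


print(settings_for('Llama-7b-GPTQ'))  # both patterns match and are merged
```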
modules/presets.py

@@ -52,4 +52,3 @@ def load_preset_for_ui(name, state):
 def generate_preset_yaml(state):
     data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
     return yaml.dump(data, sort_keys=False)
-
modules/shared.py

@@ -10,7 +10,6 @@ generation_lock = None
 model = None
 tokenizer = None
 model_name = "None"
-model_type = None
 lora_names = []
 
 # Chat variables

@@ -97,6 +96,9 @@ parser.add_argument('--settings', type=str, help='Load the default interface set
 parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 
+# Model loader
+parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen')
+
 # Accelerate/transformers
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')

@@ -139,7 +141,7 @@ parser.add_argument('--warmup_autotune', action='store_true', help='(triton) Ena
 parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fused mlp.')
 
 # AutoGPTQ
-parser.add_argument('--gptq-for-llama', action='store_true', help='Use GPTQ-for-LLaMa to load the GPTQ model instead of AutoGPTQ.')
+parser.add_argument('--gptq-for-llama', action='store_true', help='DEPRECATED')
 parser.add_argument('--autogptq', action='store_true', help='DEPRECATED')
 parser.add_argument('--triton', action='store_true', help='Use triton.')
 parser.add_argument('--no_inject_fused_attention', action='store_true', help='Do not use fused attention (lowers VRAM requirements).')

@@ -147,7 +149,7 @@ parser.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton m
 parser.add_argument('--desc_act', action='store_true', help='For models that don\'t have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.')
 
 # FlexGen
-parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
+parser.add_argument('--flexgen', action='store_true', help='DEPRECATED')
 parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
 parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
 parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")

@@ -184,7 +186,14 @@ args_defaults = parser.parse_args([])
 
 # Deprecation warnings
 if args.autogptq:
-    logger.warning('--autogptq has been deprecated and will be removed soon. AutoGPTQ is now used by default for GPTQ models.')
+    logger.warning('--autogptq has been deprecated and will be removed soon. Use --loader autogptq instead.')
+    args.loader = 'autogptq'
+if args.gptq_for_llama:
+    logger.warning('--gptq-for-llama has been deprecated and will be removed soon. Use --loader gptq-for-llama instead.')
+    args.loader = 'gptq-for-llama'
+if args.flexgen:
+    logger.warning('--flexgen has been deprecated and will be removed soon. Use --loader flexgen instead.')
+    args.loader = 'FlexGen'
 
 # Security warnings
 if args.trust_remote_code:

@@ -193,6 +202,22 @@ if args.share:
     logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
 
 
+def fix_loader_name(name):
+    name = name.lower()
+    if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']:
+        return 'llama.cpp'
+    elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
+        return 'Transformers'
+    elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']:
+        return 'AutoGPTQ'
+    elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
+        return 'GPTQ-for-LLaMa'
+
+
+if args.loader is not None:
+    args.loader = fix_loader_name(args.loader)
+
+
 def add_extension(name):
     if args.extensions is None:
         args.extensions = [name]
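With `fix_loader_name` in place, the deprecated flags and any reasonable spelling of a loader name all collapse to one canonical value before `load_model` sees it. A standalone sketch of the normalization idea (only two alias groups reproduced, purely for illustration):

```python
def normalize_loader(name):
    # Map user-supplied spellings onto the canonical loader names used as
    # keys of load_func_map and loaders_and_params.
    name = name.lower()
    if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']:
        return 'llama.cpp'
    if name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']:
        return 'AutoGPTQ'
    return name


print(normalize_loader('LlamaCpp'))   # llama.cpp
print(normalize_loader('auto-gptq'))  # AutoGPTQ
```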
modules/text_generation.py

@@ -31,7 +31,7 @@ def get_max_prompt_length(state):
 
 
 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
-    if shared.model_type in ['rwkv', 'llamacpp']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
         return input_ids

@@ -51,7 +51,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model_type in ['rwkv', 'llamacpp'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel'] or shared.args.cpu:
         return input_ids
     elif shared.args.flexgen:
         return input_ids.numpy()

@@ -99,7 +99,7 @@ def fix_galactica(s):
 
 
 def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
-    if shared.model_type == 'HF_seq2seq':
+    if shared.is_seq2seq:
         reply = decode(output_ids, state['skip_special_tokens'])
     else:
         new_tokens = len(output_ids) - len(input_ids[0])

@@ -117,7 +117,7 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
 
 
 def formatted_outputs(reply, model_name):
-    if shared.model_type == 'gpt4chan':
+    if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
         reply = fix_gpt4chan(reply)
         return reply, generate_4chan_html(reply)
     else:

@@ -142,7 +142,7 @@ def stop_everything_event():
 
 def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
     for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
-        if shared.model_type not in ['HF_seq2seq']:
+        if not shared.is_seq2seq:
             reply = question + reply
 
         yield formatted_outputs(reply, shared.model_name)

@@ -157,7 +157,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
         yield ''
         return
 
-    if shared.model_type in ['rwkv', 'llamacpp']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
         generate_func = generate_reply_custom
     elif shared.args.flexgen:
         generate_func = generate_reply_flexgen

@@ -240,7 +240,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
 
     t0 = time.time()
     try:
-        if not is_chat and shared.model_type != 'HF_seq2seq':
+        if not is_chat and not shared.is_seq2seq:
            yield ''
 
        # Generate the entire reply at once.

@@ -276,7 +276,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     finally:
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
-        new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
+        new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 

@@ -287,7 +287,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
     for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
         generate_params[k] = state[k]
 
-    if shared.model_type == 'llamacpp':
+    if shared.model.__class__.__name__ in ['LlamaCppModel']:
         for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
             generate_params[k] = state[k]
 

@@ -381,6 +381,6 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
     finally:
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
-        new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
+        new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
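Throughout this file, checks against the removed `shared.model_type` string are replaced by either the `shared.is_seq2seq` flag or the class name of the loaded model object. A standalone sketch of the class-name check (the model classes here are stand-ins, not imports from the project):

```python
class LlamaCppModel:
    pass


class RWKVModel:
    pass


class SomeTransformersModel:
    pass


def uses_custom_generation(model):
    # Same idea as the checks above: dispatch on the wrapper class name
    # instead of a separately tracked model_type string.
    return model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']


print(uses_custom_generation(LlamaCppModel()))          # True
print(uses_custom_generation(SomeTransformersModel()))  # False
```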
modules/ui.py

@@ -30,7 +30,7 @@ theme = gr.themes.Default(
 
 
 def list_model_elements():
-    elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'trust_remote_code', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'gptq_for_llama', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'triton', 'desc_act', 'no_inject_fused_attention', 'no_inject_fused_mlp', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers', 'n_ctx', 'llama_cpp_seed']
+    elements = ['loader', 'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'trust_remote_code', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'triton', 'desc_act', 'no_inject_fused_attention', 'no_inject_fused_mlp', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers', 'n_ctx', 'llama_cpp_seed']
     for i in range(torch.cuda.device_count()):
         elements.append(f'gpu_memory_{i}')
 
server.py (200 lines changed)

@@ -43,17 +43,21 @@ import yaml
 from PIL import Image
 
 import modules.extensions as extensions_module
-from modules import chat, presets, shared, training, ui, utils
+from modules import chat, loaders, presets, shared, training, ui, utils
 from modules.extensions import apply_extensions
 from modules.github import clone_or_pull_repository
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, unload_model
+from modules.models_settings import (apply_model_settings_to_state,
+                                     get_model_settings_from_yamls,
+                                     save_model_settings,
+                                     update_model_parameters)
 from modules.text_generation import (generate_reply_wrapper,
                                      get_encoded_length, stop_everything_event)
 
 
-def load_model_wrapper(selected_model, autoload=False):
+def load_model_wrapper(selected_model, loader, autoload=False):
     if not autoload:
         yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
         return

@@ -66,9 +70,12 @@ def load_model_wrapper(selected_model, autoload=False):
         shared.model_name = selected_model
         unload_model()
         if selected_model != '':
-            shared.model, shared.tokenizer = load_model(shared.model_name)
+            shared.model, shared.tokenizer = load_model(shared.model_name, loader)
 
-        yield f"Successfully loaded {selected_model}"
+        if shared.model is not None:
+            yield f"Successfully loaded {selected_model}"
+        else:
+            yield f"Failed to load {selected_model}."
     except:
         yield traceback.format_exc()
 
@@ -144,103 +151,6 @@ def download_model_wrapper(repo_id):
         yield traceback.format_exc()
 
 
-# Update the command-line arguments based on the interface values
-def update_model_parameters(state, initial=False):
-    elements = ui.list_model_elements()  # the names of the parameters
-    gpu_memories = []
-
-    for i, element in enumerate(elements):
-        if element not in state:
-            continue
-
-        value = state[element]
-        if element.startswith('gpu_memory'):
-            gpu_memories.append(value)
-            continue
-
-        if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
-            continue
-
-        # Setting null defaults
-        if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
-            value = vars(shared.args_defaults)[element]
-        elif element in ['cpu_memory'] and value == 0:
-            value = vars(shared.args_defaults)[element]
-
-        # Making some simple conversions
-        if element in ['wbits', 'groupsize', 'pre_layer']:
-            value = int(value)
-        elif element == 'cpu_memory' and value is not None:
-            value = f"{value}MiB"
-
-        if element in ['pre_layer']:
-            value = [value] if value > 0 else None
-
-        setattr(shared.args, element, value)
-
-    found_positive = False
-    for i in gpu_memories:
-        if i > 0:
-            found_positive = True
-            break
-
-    if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
-        if found_positive:
-            shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
-        else:
-            shared.args.gpu_memory = None
-
-
-def get_model_specific_settings(model):
-    settings = shared.model_config
-    model_settings = {}
-
-    for pat in settings:
-        if re.match(pat.lower(), model.lower()):
-            for k in settings[pat]:
-                model_settings[k] = settings[pat][k]
-
-    return model_settings
-
-
-def load_model_specific_settings(model, state):
-    model_settings = get_model_specific_settings(model)
-    for k in model_settings:
-        if k in state:
-            state[k] = model_settings[k]
-
-    return state
-
-
-def save_model_settings(model, state):
-    if model == 'None':
-        yield ("Not saving the settings because no model is loaded.")
-        return
-
-    with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
-        if p.exists():
-            user_config = yaml.safe_load(open(p, 'r').read())
-        else:
-            user_config = {}
-
-        model_regex = model + '$' # For exact matches
-        for _dict in [user_config, shared.model_config]:
-            if model_regex not in _dict:
-                _dict[model_regex] = {}
-
-        if model_regex not in user_config:
-            user_config[model_regex] = {}
-
-        for k in ui.list_model_elements():
-            user_config[model_regex][k] = state[k]
-            shared.model_config[model_regex][k] = state[k]
-
-        with open(p, 'w') as f:
-            f.write(yaml.dump(user_config, sort_keys=False))
-
-        yield (f"Settings for {model} saved to {p}")
-
-
 def create_model_menus():
     # Finding the default values for the GPU and CPU memories
     total_mem = []
@@ -283,88 +193,70 @@
     with gr.Row():
         with gr.Column():
+            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp"], value=None)
             with gr.Box():
-                gr.Markdown('Transformers')
                 with gr.Row():
                     with gr.Column():
                         for i in range(len(total_mem)):
                             shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i])
 
                         shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
-
-                    with gr.Column():
-                        shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
-                        shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
-                        shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
-                        shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
-                        shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
-                        shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
-
-            with gr.Box():
-                gr.Markdown('Transformers 4-bit')
-                with gr.Row():
-                    with gr.Column():
-                        shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
-                        shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
-
-                    with gr.Column():
                         shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype)
                         shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type)
-
-            shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
-            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
-            shared.gradio['download_model_button'] = gr.Button("Download")
-
-        with gr.Column():
-            with gr.Box():
-                with gr.Row():
-                    with gr.Column():
-                        gr.Markdown('GPTQ')
-                        shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
-                        shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention", value=shared.args.no_inject_fused_attention, info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Disable if running low on VRAM.')
-                        shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.')
-                        shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
-                        shared.gradio['gptq_for_llama'] = gr.Checkbox(label="gptq-for-llama", value=shared.args.gptq_for_llama, info='Use GPTQ-for-LLaMa loader instead of AutoGPTQ. pre_layer should be used for CPU offloading instead of gpu-memory.')
-
-                    with gr.Column():
-                        with gr.Row():
-                            shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
-                            shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
-
-                        shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
-                        shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
-
-            with gr.Box():
-                gr.Markdown('llama.cpp')
-                with gr.Row():
-                    with gr.Column():
                         shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
                         shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
                         shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
                         shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=8192, step=1, label="n_ctx", value=shared.args.n_ctx)
+                        shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
+                        shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
+                        shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
+                        shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
+                        shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
 
                     with gr.Column():
+                        shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
+                        shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention", value=shared.args.no_inject_fused_attention, info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Disable if running low on VRAM.')
+                        shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.')
+                        shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
+                        shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
+                        shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
+                        shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
+                        shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
+                        shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
+                        shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit")
+                        shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
                        shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
                        shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
                        shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
+                        shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
+                        shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
+
+        with gr.Column():
+            with gr.Row():
+                shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
+
+            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
+            shared.gradio['download_model_button'] = gr.Button("Download")
 
     with gr.Row():
         shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
 
+    shared.gradio['loader'].change(loaders.make_loader_params_visible, shared.gradio['loader'], [shared.gradio[k] for k in loaders.get_all_params()])
+
     # In this event handler, the interface state is read and updated
     # with the model defaults (if any), and then the model is loaded
     # unless "autoload_model" is unchecked
     shared.gradio['model_menu'].change(
         ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
load_model_specific_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then(
|
apply_model_settings_to_state, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then(
|
||||||
ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
|
ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
|
||||||
update_model_parameters, shared.gradio['interface_state'], None).then(
|
update_model_parameters, shared.gradio['interface_state'], None).then(
|
||||||
load_model_wrapper, [shared.gradio[k] for k in ['model_menu', 'autoload_model']], shared.gradio['model_status'], show_progress=False)
|
load_model_wrapper, [shared.gradio[k] for k in ['model_menu', 'loader', 'autoload_model']], shared.gradio['model_status'], show_progress=False)
|
||||||
|
|
||||||
load.click(
|
load.click(
|
||||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
update_model_parameters, shared.gradio['interface_state'], None).then(
|
update_model_parameters, shared.gradio['interface_state'], None).then(
|
||||||
partial(load_model_wrapper, autoload=True), shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
|
partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']], shared.gradio['model_status'], show_progress=False)
|
||||||
|
|
||||||
unload.click(
|
unload.click(
|
||||||
unload_model, None, None).then(
|
unload_model, None, None).then(
|
||||||
@@ -374,7 +266,7 @@ def create_model_menus():
         unload_model, None, None).then(
         ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
         update_model_parameters, shared.gradio['interface_state'], None).then(
-        partial(load_model_wrapper, autoload=True), shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
+        partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']], shared.gradio['model_status'], show_progress=False)

     save_settings.click(
         ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
@@ -1100,7 +992,7 @@ if __name__ == "__main__":
     # If any model has been selected, load it
     if shared.model_name != 'None':
-        model_settings = get_model_specific_settings(shared.model_name)
+        model_settings = get_model_settings_from_yamls(shared.model_name)
         shared.settings.update(model_settings)  # hijacking the interface defaults
         update_model_parameters(model_settings, initial=True)  # hijacking the command-line arguments
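The rename from `get_model_specific_settings` to `get_model_settings_from_yamls` points to per-model defaults being read from YAML before they are merged into the interface defaults and command-line arguments above. A rough sketch under assumed conventions — the `models/config.yaml` path and wildcard-style keys are illustrative, not necessarily the project's actual layout:

# Illustrative sketch: look up per-model defaults from a YAML file whose keys
# are wildcard patterns matched against the model name. The file name and key
# style are assumptions for the example.
import fnmatch
from pathlib import Path

import yaml


def get_model_settings_from_yamls(model_name, config_path='models/config.yaml'):
    settings = {}
    path = Path(config_path)
    if not path.exists():
        return settings

    config = yaml.safe_load(path.read_text()) or {}
    for pattern, overrides in config.items():
        # Every matching pattern contributes; later entries override earlier ones
        if fnmatch.fnmatch(model_name.lower(), pattern.lower()):
            settings.update(overrides)

    return settings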
@@ -1117,6 +1009,10 @@ if __name__ == "__main__":
             'instruction_template': shared.settings['instruction_template']
         })

+    shared.persistent_interface_state.update({
+        'loader': shared.args.loader or 'Transformers',
+    })
+
     shared.generation_lock = Lock()
     # Launch the web UI
     create_interface()
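The added `persistent_interface_state` entry seeds a default loader for the UI: an explicit `--loader` flag wins, otherwise 'Transformers' is used. A small sketch of how such a pre-seeded value could feed the loader dropdown when the interface is built — the helper names and the choices list are illustrative assumptions, not the project's actual `create_interface()`:

# Illustrative sketch: seed a persistent state dict and use it as the initial
# value of the loader dropdown. Names and choices are assumptions for the example.
import gradio as gr

persistent_interface_state = {}


def seed_default_loader(cli_loader=None):
    # Mirrors the fallback above: an explicit --loader wins, else Transformers
    persistent_interface_state['loader'] = cli_loader or 'Transformers'


def build_loader_dropdown():
    return gr.Dropdown(
        label='Model loader',
        choices=['Transformers', 'AutoGPTQ', 'GPTQ-for-LLaMa', 'llama.cpp', 'RWKV', 'FlexGen'],
        value=persistent_interface_state.get('loader', 'Transformers'),
    )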