Merge branch 'main' into Brawlence-main

oobabooga 2023-03-26 23:40:51 -03:00
commit e07c9e3093
20 changed files with 270 additions and 203 deletions

.gitignore

@@ -1,26 +1,21 @@
-cache/*
-characters/*
-extensions/silero_tts/outputs/*
-extensions/elevenlabs_tts/outputs/*
-extensions/sd_api_pictures/outputs/*
-logs/*
-loras/*
-models/*
-softprompts/*
-torch-dumps/*
+cache
+characters
+training/datasets
+extensions/silero_tts/outputs
+extensions/elevenlabs_tts/outputs
+extensions/sd_api_pictures/outputs
+logs
+loras
+models
+softprompts
+torch-dumps
 *pycache*
 */*pycache*
 */*/pycache*
 venv/
 .venv/
+repositories
 settings.json
 img_bot*
 img_me*
-!characters/Example.json
-!characters/Example.png
-!loras/place-your-loras-here.txt
-!models/place-your-models-here.txt
-!softprompts/place-your-softprompts-here.txt
-!torch-dumps/place-your-pt-models-here.txt


@@ -27,7 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
 * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
 * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
-* [LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
+* [LLaMA model, including 4-bit GPTQ support](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
 * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
 * [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs).
 * Supports softprompts.
@@ -84,10 +84,6 @@ pip install -r requirements.txt
 >
 > For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859
-### Alternative: native Windows installation
-As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
 ### Alternative: one-click installers
 [oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
@@ -101,7 +97,13 @@ Just download the zip above, extract it, and double click on "install". The web
 Source codes: https://github.com/oobabooga/one-click-installers
-This method lags behind the newest developments and does not support 8-bit mode on Windows without additional set up: https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134, https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652
+> **Note**
+>
+> To get 8-bit and 4-bit models working in your 1-click Windows installation, you can use the [one-click-bandaid](https://github.com/ClayShoaf/oobabooga-one-click-bandaid).
+### Alternative: native Windows installation
+As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
 ### Alternative: Docker
@@ -174,10 +176,10 @@ Optionally, you can use the following command-line flags:
 | `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
 | `--cpu` | Use the CPU to generate text.|
 | `--load-in-8bit` | Load the model with 8-bit precision.|
-| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
-| `--gptq-bits GPTQ_BITS` | GPTQ: Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
-| `--gptq-model-type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
-| `--gptq-pre-layer GPTQ_PRE_LAYER` | GPTQ: The number of layers to preload. |
+| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
+| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. |
+| `--groupsize GROUPSIZE` | GPTQ: Group size. |
+| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
 | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
 | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
 | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
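As a quick illustration of the renamed GPTQ flags, a 4-bit LLaMA checkpoint quantized with a group size of 128 would be launched along these lines (the model folder name is hypothetical, and the usual `--model` flag is assumed for selecting it):

```
python server.py --model llama-7b-4bit-128g --wbits 4 --groupsize 128 --model_type llama
```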


@@ -23,3 +23,9 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
 .pending.svelte-1ed2p3z {
     opacity: 1;
 }
+
+#extensions {
+    padding: 0;
+    padding: 0;
+}


@@ -54,3 +54,13 @@ ol li p, ul li p {
 .gradio-container-3-18-0 .prose * h1, h2, h3, h4 {
     color: white;
 }
+
+.gradio-container {
+    max-width: 100% !important;
+    padding-top: 0 !important;
+}
+
+#extensions {
+    padding: 15px;
+    padding: 15px;
+}


@@ -11,7 +11,7 @@ let extensions = document.getElementById('extensions');
 main_parent.addEventListener('click', function(e) {
     // Check if the main element is visible
     if (main.offsetHeight > 0 && main.offsetWidth > 0) {
-        extensions.style.display = 'block';
+        extensions.style.display = 'flex';
     } else {
         extensions.style.display = 'none';
     }


@@ -116,10 +116,11 @@ def get_download_links_from_huggingface(model, branch):
         is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
         is_safetensors = re.match("model.*\.safetensors", fname)
+        is_pt = re.match(".*\.pt", fname)
         is_tokenizer = re.match("tokenizer.*\.model", fname)
-        is_text = re.match(".*\.(txt|json|py)", fname) or is_tokenizer
-        if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
+        is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
+        if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
             if is_text:
                 links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                 classifications.append('text')
@@ -132,7 +133,8 @@ def get_download_links_from_huggingface(model, branch):
             elif is_pytorch:
                 has_pytorch = True
                 classifications.append('pytorch')
+            elif is_pt:
+                classifications.append('pt')
 
     cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
     cursor = base64.b64encode(cursor)
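To make the updated classification rules easier to follow, here is a small self-contained check of how the patterns sort typical repository files; the filenames are made up for illustration:

```python
import re

# Hypothetical filenames, just to illustrate the updated classification rules.
for fname in ["pytorch_model-00001-of-00002.bin", "model.safetensors",
              "llama-7b-4bit.pt", "tokenizer.model", "README.md"]:
    is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
    is_safetensors = re.match("model.*\.safetensors", fname)
    is_pt = re.match(".*\.pt", fname)
    is_tokenizer = re.match("tokenizer.*\.model", fname)
    is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
    matches = [name for name, m in [("pytorch", is_pytorch), ("safetensors", is_safetensors),
                                    ("pt", is_pt), ("tokenizer", is_tokenizer), ("text", is_text)] if m]
    print(fname, "->", matches)
```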


@@ -1,8 +1,9 @@
+import json
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 from modules import shared
-from modules.text_generation import generate_reply, encode
-import json
+from modules.text_generation import encode, generate_reply
 
 params = {
     'port': 5000,
@@ -87,5 +88,5 @@ def run_server():
     print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
     server.serve_forever()
 
-def ui():
+def setup():
     Thread(target=run_server, daemon=True).start()


@@ -14,18 +14,21 @@ import opt
 def load_quantized(model_name):
-    if not shared.args.gptq_model_type:
+    if not shared.args.model_type:
         # Try to determine model type from model name
-        model_type = model_name.split('-')[0].lower()
-        if model_type not in ('llama', 'opt'):
-            print("Can't determine model type from model name. Please specify it manually using --gptq-model-type "
+        if model_name.lower().startswith(('llama', 'alpaca')):
+            model_type = 'llama'
+        elif model_name.lower().startswith(('opt', 'galactica')):
+            model_type = 'opt'
+        else:
+            print("Can't determine model type from model name. Please specify it manually using --model_type "
                   "argument")
             exit()
     else:
-        model_type = shared.args.gptq_model_type.lower()
+        model_type = shared.args.model_type.lower()
 
     if model_type == 'llama':
-        if not shared.args.gptq_pre_layer:
+        if not shared.args.pre_layer:
             load_quant = llama.load_quant
         else:
             load_quant = llama_inference_offload.load_quant
@@ -35,33 +38,44 @@ def load_quantized(model_name):
         print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
         exit()
 
+    # Now we are going to try to locate the quantized model file.
     path_to_model = Path(f'models/{model_name}')
-    if path_to_model.name.lower().startswith('llama-7b'):
-        pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-13b'):
-        pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-30b'):
-        pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-65b'):
-        pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt'
-    else:
-        pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
-
-    # Try to find the .pt both in models/ and in the subfolder
+    found_pts = list(path_to_model.glob("*.pt"))
+    found_safetensors = list(path_to_model.glob("*.safetensors"))
     pt_path = None
-    for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
-        if path.exists():
-            pt_path = path
+
+    if len(found_pts) == 1:
+        pt_path = found_pts[0]
+    elif len(found_safetensors) == 1:
+        pt_path = found_safetensors[0]
+    else:
+        if path_to_model.name.lower().startswith('llama-7b'):
+            pt_model = f'llama-7b-{shared.args.wbits}bit'
+        elif path_to_model.name.lower().startswith('llama-13b'):
+            pt_model = f'llama-13b-{shared.args.wbits}bit'
+        elif path_to_model.name.lower().startswith('llama-30b'):
+            pt_model = f'llama-30b-{shared.args.wbits}bit'
+        elif path_to_model.name.lower().startswith('llama-65b'):
+            pt_model = f'llama-65b-{shared.args.wbits}bit'
+        else:
+            pt_model = f'{model_name}-{shared.args.wbits}bit'
+
+        # Try to find the .safetensors or .pt both in models/ and in the subfolder
+        for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+            if path.exists():
+                print(f"Found {path}")
+                pt_path = path
+                break
 
     if not pt_path:
-        print(f"Could not find {pt_model}, exiting...")
+        print("Could not find the quantized model in .pt or .safetensors format, exiting...")
         exit()
 
     # qwopqwop200's offload
-    if shared.args.gptq_pre_layer:
-        model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits, shared.args.gptq_pre_layer)
+    if shared.args.pre_layer:
+        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
     else:
-        model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
+        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
 
     # accelerate offload (doesn't work properly)
     if shared.args.gpu_memory:


@@ -1,22 +1,43 @@
 from pathlib import Path
 
+import torch
+
 import modules.shared as shared
 from modules.models import load_model
+from modules.text_generation import clear_torch_cache
+
+
+def reload_model():
+    shared.model = shared.tokenizer = None
+    clear_torch_cache()
+    shared.model, shared.tokenizer = load_model(shared.model_name)
+
 
 def add_lora_to_model(lora_name):
 
     from peft import PeftModel
 
-    # Is there a more efficient way of returning to the base model?
-    if lora_name == "None":
-        print("Reloading the model to remove the LoRA...")
-        shared.model, shared.tokenizer = load_model(shared.model_name)
-    else:
-        # Why doesn't this work in 16-bit mode?
-        print(f"Adding the LoRA {lora_name} to the model...")
+    # If a LoRA had been previously loaded, or if we want
+    # to unload a LoRA, reload the model
+    if shared.lora_name != "None" or lora_name == "None":
+        reload_model()
+    shared.lora_name = lora_name
 
+    if lora_name != "None":
+        print(f"Adding the LoRA {lora_name} to the model...")
         params = {}
-        params['device_map'] = {'': 0}
-        #params['dtype'] = shared.model.dtype
+        if not shared.args.cpu:
+            params['dtype'] = shared.model.dtype
+            if hasattr(shared.model, "hf_device_map"):
+                params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
+            elif shared.args.load_in_8bit:
+                params['device_map'] = {'': 0}
+
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
+        if not shared.args.load_in_8bit and not shared.args.cpu:
+            shared.model.half()
+            if not hasattr(shared.model, "hf_device_map"):
+                if torch.has_mps:
+                    device = torch.device('mps')
+                    shared.model = shared.model.to(device)
+                else:
+                    shared.model = shared.model.cuda()
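A hypothetical usage sketch of the updated loader (the LoRA name is an example); note how passing "None" now simply reloads the base model:

```python
from modules.LoRA import add_lora_to_model

add_lora_to_model("alpaca-lora-7b")  # reloads the base model if another LoRA was active, then applies loras/alpaca-lora-7b
add_lora_to_model("None")            # "None" triggers a plain reload, i.e. removes any active LoRA
```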


@@ -45,11 +45,11 @@ class RWKVModel:
             token_stop = token_stop
         )
 
-        return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
+        return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
 
     def generate_with_streaming(self, **kwargs):
         with Iteratorize(self.generate, kwargs, callback=None) as generator:
-            reply = kwargs['context']
+            reply = ''
             for token in generator:
                 reply += token
                 yield reply


@@ -11,24 +11,22 @@ import modules.shared as shared
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
 
-    def __init__(self, sentinel_token_ids: torch.LongTensor,
-                 starting_idx: int):
+    def __init__(self, sentinel_token_ids: list[torch.LongTensor], starting_idx: int):
         transformers.StoppingCriteria.__init__(self)
         self.sentinel_token_ids = sentinel_token_ids
         self.starting_idx = starting_idx
 
-    def __call__(self, input_ids: torch.LongTensor,
-                 _scores: torch.FloatTensor) -> bool:
+    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
         for sample in input_ids:
             trimmed_sample = sample[self.starting_idx:]
 
-            # Can't unfold, output is still too tiny. Skip.
-            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
-                continue
-
-            for window in trimmed_sample.unfold(
-                    0, self.sentinel_token_ids.shape[-1], 1):
-                if torch.all(torch.eq(self.sentinel_token_ids, window)):
-                    return True
+            for i in range(len(self.sentinel_token_ids)):
+                # Can't unfold, output is still too tiny. Skip.
+                if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
+                    continue
+
+                for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
+                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
+                        return True
+
         return False
 
 class Stream(transformers.StoppingCriteria):
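A minimal, self-contained check of the updated criteria, using the `_SentinelTokenStoppingCriteria` defined above; the token ids are made up (in the web UI they come from `encode()` applied to each stopping string):

```python
import torch

# Two sentinel sequences; the second one appears at the end of the generated ids,
# in the region after the prompt, so the criteria fires.
sentinels = [torch.tensor([[13, 3492, 29901]]), torch.tensor([[13, 7900, 29901]])]
prompt_length = 4
generated_ids = torch.tensor([[1, 2, 3, 4, 99, 98, 13, 7900, 29901]])

criteria = _SentinelTokenStoppingCriteria(sentinel_token_ids=sentinels, starting_idx=prompt_length)
print(criteria(generated_ids, None))  # True
```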


@@ -33,12 +33,14 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
         i = len(shared.history['internal'])-1
         while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
             rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
-            if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
-                rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n")
+            prev_user_input = shared.history['internal'][i][0]
+            if len(prev_user_input) > 0 and prev_user_input != '<|BEGIN-VISIBLE-CHAT|>':
+                rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
             i -= 1
 
     if not impersonate:
-        rows.append(f"{name1}: {user_input}\n")
+        if len(user_input) > 0:
+            rows.append(f"{name1}: {user_input}\n")
         rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
         limit = 3
     else:
@@ -51,41 +53,31 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     prompt = ''.join(rows)
     return prompt
 
-def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False):
+def extract_message_from_reply(reply, name1, name2, check):
     next_character_found = False
 
-    asker = name1 if not impersonate else name2
-    replier = name2 if not impersonate else name1
-
-    previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", question)]
-    idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", reply)]
-    idx = idx[max(len(previous_idx)-1, 0)]
-
-    if not impersonate:
-        reply = reply[idx + 1 + len(apply_extensions(f"{replier}:", "bot_prefix")):]
-    else:
-        reply = reply[idx + 1 + len(f"{replier}:"):]
-
     if check:
         lines = reply.split('\n')
         reply = lines[0].strip()
         if len(lines) > 1:
             next_character_found = True
     else:
-        idx = reply.find(f"\n{asker}:")
-        if idx != -1:
-            reply = reply[:idx]
-            next_character_found = True
-        reply = fix_newlines(reply)
+        for string in [f"\n{name1}:", f"\n{name2}:"]:
+            idx = reply.find(string)
+            if idx != -1:
+                reply = reply[:idx]
+                next_character_found = True
 
         # If something like "\nYo" is generated just before "\nYou:"
         # is completed, trim it
-        next_turn = f"\n{asker}:"
-        for j in range(len(next_turn)-1, 0, -1):
-            if reply[-j:] == next_turn[:j]:
-                reply = reply[:-j]
-                break
+        if not next_character_found:
+            for string in [f"\n{name1}:", f"\n{name2}:"]:
+                for j in range(len(string)-1, 0, -1):
+                    if reply[-j:] == string[:j]:
+                        reply = reply[:-j]
+                        break
 
+    reply = fix_newlines(reply)
     return reply, next_character_found
 
 def stop_everything_event():
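A small worked example of the partial-stop trimming above, using the new default names "You" and "Assistant" as the stop markers:

```python
reply = "Sure, I can help with that.\nYo"  # generation stopped partway through "\nYou:"
for string in ["\nYou:", "\nAssistant:"]:
    for j in range(len(string)-1, 0, -1):
        if reply[-j:] == string[:j]:
            reply = reply[:-j]
            break

print(repr(reply))  # 'Sure, I can help with that.'
```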
@@ -125,12 +117,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     yield shared.history['visible']+[[visible_text, shared.processing_message]]
 
     # Generate
-    reply = ''
+    cumulative_reply = ''
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+            reply = cumulative_reply + reply
 
             # Extracting the reply
-            reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
             visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
             visible_reply = apply_extensions(visible_reply, "output")
             if shared.args.chat:
@@ -152,6 +145,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
             if next_character_found:
                 break
 
+        cumulative_reply = reply
+
     yield shared.history['visible']
 
 def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
@@ -162,16 +157,21 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
     prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
 
-    reply = ''
     # Yield *Is typing...*
     yield shared.processing_message
+
+    cumulative_reply = ''
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_string=f"\n{name2}:"):
-            reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+            reply = cumulative_reply + reply
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
             yield reply
             if next_character_found:
                 break
 
+        cumulative_reply = reply
+
     yield reply
 
 def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):


@@ -7,6 +7,7 @@ import modules.shared as shared
 
 state = {}
 available_extensions = []
+setup_called = False
 
 def load_extensions():
     global state
@@ -39,6 +40,8 @@ def apply_extensions(text, typ):
     return text
 
 def create_extensions_block():
+    global setup_called
+
     # Updating the default values
     for extension, name in iterator():
         if hasattr(extension, 'params'):
@@ -47,10 +50,21 @@ def create_extensions_block():
                 if _id in shared.settings:
                     extension.params[param] = shared.settings[_id]
 
+    should_display_ui = False
+
+    # Running setup function
+    if not setup_called:
+        for extension, name in iterator():
+            if hasattr(extension, "setup"):
+                extension.setup()
+            if hasattr(extension, "ui"):
+                should_display_ui = True
+        setup_called = True
+
     # Creating the extension ui elements
-    if len(state) > 0:
-        with gr.Box(elem_id="extensions"):
-            gr.Markdown("Extensions")
+    if should_display_ui:
+        with gr.Column(elem_id="extensions"):
             for extension, name in iterator():
+                gr.Markdown(f"\n### {name}")
                 if hasattr(extension, "ui"):
                     extension.ui()
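For reference, a hypothetical extension script showing the hooks the block above looks for: a `params` dict (merged with matching entries in the settings), an optional `setup()` that now runs once the first time the block is built, and an optional `ui()` whose Gradio elements end up inside the extensions column. The file path and contents are illustrative, not taken from this commit:

```python
# extensions/example/script.py (hypothetical)
import gradio as gr

params = {
    'greeting': 'Hello from the example extension',
}

def setup():
    # One-time initialization; called the first time create_extensions_block() runs.
    print("example extension: setup complete")

def ui():
    # Elements created here are rendered under the per-extension heading
    # that create_extensions_block() adds inside the #extensions column.
    gr.Markdown(params['greeting'])
```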


@@ -119,13 +119,13 @@ def load_html_image(paths):
 
 def generate_chat_html(history, name1, name2, character):
     output = f'<style>{cai_css}</style><div class="chat" id="chat">'
 
     img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
     img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
 
     for i,_row in enumerate(history[::-1]):
         row = [convert_to_markdown(entry) for entry in _row]
 
         output += f"""
               <div class="message">
                 <div class="circle-bot">
@@ -142,22 +142,24 @@ def generate_chat_html(history, name1, name2, character):
               </div>
             """
 
-        if not (i == len(history)-1 and len(row[0]) == 0):
-            output += f"""
-                  <div class="message">
-                    <div class="circle-you">
-                      {img_me}
-                    </div>
-                    <div class="text">
-                      <div class="username">
-                        {name1}
-                      </div>
-                      <div class="message-body">
-                        {row[0]}
-                      </div>
-                    </div>
-                  </div>
-                """
+        if len(row[0]) == 0: # don't display empty user messages
+            continue
+
+        output += f"""
+              <div class="message">
+                <div class="circle-you">
+                  {img_me}
+                </div>
+                <div class="text">
+                  <div class="username">
+                    {name1}
+                  </div>
+                  <div class="message-body">
+                    {row[0]}
+                  </div>
+                </div>
+              </div>
+            """
 
     output += "</div>"
     return output


@@ -44,7 +44,7 @@ def load_model(model_name):
     shared.is_RWKV = model_name.lower().startswith('rwkv-')
 
     # Default settings
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
@@ -95,7 +95,7 @@ def load_model(model_name):
         return model, tokenizer
 
     # Quantized model
-    elif shared.args.gptq_bits > 0:
+    elif shared.args.wbits > 0:
         from modules.GPTQ_loader import load_quantized
 
         model = load_quantized(model_name)


@@ -27,9 +27,9 @@ settings = {
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
     'max_new_tokens_max': 2000,
-    'name1': 'Person 1',
-    'name2': 'Person 2',
-    'context': 'This is a conversation between two people.',
+    'name1': 'You',
+    'name2': 'Assistant',
+    'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
     'stop_at_newline': False,
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
@@ -52,7 +52,8 @@ settings = {
         'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
         '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
         '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
-        'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
+        'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>',
+        'alpaca-*': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n",
     },
     'lora_prompts': {
         'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
@@ -78,10 +79,15 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
-parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.')
-parser.add_argument('--gptq-bits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA and OPT.')
-parser.add_argument('--gptq-model-type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMa and OPT are supported.')
-parser.add_argument('--gptq-pre-layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
+parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --wbits instead.')
+parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
+parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
+parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
+parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.')
+parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
+parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
 parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
 parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
@@ -109,6 +115,8 @@ parser.add_argument('--verbose', action='store_true', help='Print the prompts to
 args = parser.parse_args()
 
 # Provisional, this will be deleted later
-if args.load_in_4bit:
-    print("Warning: --load-in-4bit is deprecated and will be removed. Use --gptq-bits 4 instead.\n")
-    args.gptq_bits = 4
+deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
+for k in deprecated_dict:
+    if eval(f"args.{k}") != deprecated_dict[k][1]:
+        print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
+        exec(f"args.{deprecated_dict[k][0]} = args.{k}")


@@ -99,25 +99,37 @@ def set_manual_seed(seed):
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_string=None):
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
     set_manual_seed(seed)
     t0 = time.time()
 
+    original_question = question
+    if not (shared.args.chat or shared.args.cai_chat):
+        question = apply_extensions(question, "input")
+    if shared.args.verbose:
+        print(f"\n\n{question}\n--------------------\n")
+
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
     if shared.is_RWKV:
         try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
+                if not (shared.args.chat or shared.args.cai_chat):
+                    reply = original_question + apply_extensions(reply, "output")
                 yield formatted_outputs(reply, shared.model_name)
             else:
                 if not (shared.args.chat or shared.args.cai_chat):
                     yield formatted_outputs(question, shared.model_name)
+
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+                    if not (shared.args.chat or shared.args.cai_chat):
+                        reply = original_question + apply_extensions(reply, "output")
                     yield formatted_outputs(reply, shared.model_name)
+
         except Exception:
             traceback.print_exc()
         finally:
@@ -127,12 +139,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
             return
 
-    original_question = question
-    if not (shared.args.chat or shared.args.cai_chat):
-        question = apply_extensions(question, "input")
-    if shared.args.verbose:
-        print(f"\n\n{question}\n--------------------\n")
-
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
     output = input_ids[0]
@@ -142,9 +148,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     if eos_token is not None:
         eos_token_ids.append(int(encode(eos_token)[0][-1]))
 
     stopping_criteria_list = transformers.StoppingCriteriaList()
-    if stopping_string is not None:
-        # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
-        t = encode(stopping_string, 0, add_special_tokens=False)
+    if type(stopping_strings) is list and len(stopping_strings) > 0:
+        t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
     generate_params = {}
@@ -195,12 +200,10 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
+            new_tokens = len(output) - len(input_ids[0])
+            reply = decode(output[-new_tokens:])
             if not (shared.args.chat or shared.args.cai_chat):
-                new_tokens = len(output) - len(input_ids[0])
-                reply = decode(output[-new_tokens:])
                 reply = original_question + apply_extensions(reply, "output")
-            else:
-                reply = decode(output)
 
             yield formatted_outputs(reply, shared.model_name)
@@ -223,12 +226,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 for output in generator:
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
+                    new_tokens = len(output) - len(input_ids[0])
+                    reply = decode(output[-new_tokens:])
                     if not (shared.args.chat or shared.args.cai_chat):
-                        new_tokens = len(output) - len(input_ids[0])
-                        reply = decode(output[-new_tokens:])
                         reply = original_question + apply_extensions(reply, "output")
-                    else:
-                        reply = decode(output)
 
                     if output[-1] in eos_token_ids:
                         break
@@ -244,12 +246,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                     output = shared.model.generate(**generate_params)[0]
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
+                    new_tokens = len(output) - len(original_input_ids[0])
+                    reply = decode(output[-new_tokens:])
                     if not (shared.args.chat or shared.args.cai_chat):
-                        new_tokens = len(output) - len(original_input_ids[0])
-                        reply = decode(output[-new_tokens:])
                         reply = original_question + apply_extensions(reply, "output")
-                    else:
-                        reply = decode(output)
 
                     if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                         break
@@ -269,5 +270,5 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         traceback.print_exc()
     finally:
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens, context {len(original_input_ids[0])})")
         return


@@ -1,7 +1,7 @@
 accelerate==0.17.1
 bitsandbytes==0.37.1
 flexgen==0.1.7
-gradio==3.18.0
+gradio==3.23.0
 markdown
 numpy
 peft==0.2.0


@@ -1,4 +1,3 @@
-import gc
 import io
 import json
 import re
@@ -8,7 +7,6 @@ import zipfile
 from pathlib import Path
 
 import gradio as gr
-import torch
 
 import modules.chat as chat
 import modules.extensions as extensions_module
@@ -17,7 +15,7 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import generate_reply
+from modules.text_generation import clear_torch_cache, generate_reply
 
 # Loading custom settings
 settings_file = None
@@ -56,9 +54,7 @@ def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
         shared.model = shared.tokenizer = None
-        if not shared.args.cpu:
-            gc.collect()
-            torch.cuda.empty_cache()
+        clear_torch_cache()
         shared.model, shared.tokenizer = load_model(shared.model_name)
 
     return selected_model
@@ -75,13 +71,8 @@ def unload_model():
     print("Model weights unloaded.")
 
 def load_lora_wrapper(selected_lora):
-    shared.lora_name = selected_lora
-    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
-
-    if not shared.args.cpu:
-        gc.collect()
-        torch.cuda.empty_cache()
     add_lora_to_model(selected_lora)
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
 
     return selected_lora, default_text
@@ -258,14 +249,13 @@ else:
     shared.model_name = available_models[i]
 shared.model, shared.tokenizer = load_model(shared.model_name)
 if shared.args.lora:
-    print(shared.args.lora)
-    shared.lora_name = shared.args.lora
-    add_lora_to_model(shared.lora_name)
+    add_lora_to_model(shared.args.lora)
 
 # Default UI settings
 default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
-default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
-if default_text == '':
+if shared.lora_name != "None":
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
+else:
     default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
@@ -354,7 +344,7 @@ def create_interface():
             gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
+            shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
             shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
             shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
@@ -395,19 +385,22 @@ def create_interface():
     elif shared.args.notebook:
         with gr.Tab("Text generation", elem_id="main"):
-            with gr.Tab('Raw'):
-                shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
-            with gr.Tab('Markdown'):
-                shared.gradio['markdown'] = gr.Markdown()
-            with gr.Tab('HTML'):
-                shared.gradio['html'] = gr.HTML()
-
             with gr.Row():
-                shared.gradio['Stop'] = gr.Button('Stop')
-                shared.gradio['Generate'] = gr.Button('Generate')
-
-            shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-
-            create_model_and_preset_menus()
+                with gr.Column(scale=4):
+                    with gr.Tab('Raw'):
+                        shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25)
+                    with gr.Tab('Markdown'):
+                        shared.gradio['markdown'] = gr.Markdown()
+                    with gr.Tab('HTML'):
+                        shared.gradio['html'] = gr.HTML()
+
+                    with gr.Row():
+                        shared.gradio['Stop'] = gr.Button('Stop')
+                        shared.gradio['Generate'] = gr.Button('Generate')
+
+                with gr.Column(scale=1):
+                    shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+
+                    create_model_and_preset_menus()
 
         with gr.Tab("Parameters", elem_id="parameters"):
             create_settings_menus(default_preset)


@@ -2,9 +2,9 @@
     "max_new_tokens": 200,
     "max_new_tokens_min": 1,
     "max_new_tokens_max": 2000,
-    "name1": "Person 1",
-    "name2": "Person 2",
-    "context": "This is a conversation between two people.",
+    "name1": "You",
+    "name2": "Assistant",
+    "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
    "stop_at_newline": false,
    "chat_prompt_size": 2048,
    "chat_prompt_size_min": 0,