From d537b28d02782eaccf28741fb6a33ec2f39d1cf5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com>
Date: Tue, 14 Mar 2023 06:49:10 +0300
Subject: [PATCH 01/46] Extension: Stable Diffusion Api integration
Lets the bot answer you with a picture!
---
.gitignore | 1 +
extensions/sd_api_pictures/requirements.txt | 5 +
extensions/sd_api_pictures/script.py | 186 ++++++++++++++++++++
3 files changed, 192 insertions(+)
create mode 100644 extensions/sd_api_pictures/requirements.txt
create mode 100644 extensions/sd_api_pictures/script.py
diff --git a/.gitignore b/.gitignore
index 1b7f0fb8..80244a30 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@ cache/*
characters/*
extensions/silero_tts/outputs/*
extensions/elevenlabs_tts/outputs/*
+extensions/sd_api_pictures/outputs/*
logs/*
models/*
softprompts/*
diff --git a/extensions/sd_api_pictures/requirements.txt b/extensions/sd_api_pictures/requirements.txt
new file mode 100644
index 00000000..5f94b3eb
--- /dev/null
+++ b/extensions/sd_api_pictures/requirements.txt
@@ -0,0 +1,5 @@
+gradio
+modules
+Pillow
+requests
+torch
\ No newline at end of file
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
new file mode 100644
index 00000000..b6854a02
--- /dev/null
+++ b/extensions/sd_api_pictures/script.py
@@ -0,0 +1,186 @@
+import json
+import base64
+import requests
+import io
+
+from io import BytesIO
+from PIL import Image, PngImagePlugin
+from pathlib import Path
+
+import gradio as gr
+import torch
+
+import modules.chat as chat
+import modules.shared as shared
+
+torch._C._jit_set_profiling_mode(False)
+
+# parameters which can be customized in settings.json of webui
+params = {
+ 'enable_SD_api': False,
+ 'address': 'http://127.0.0.1:7860',
+ 'save_img': False,
+ 'SD_model': 'NeverEndingDream', # not really used right now
+ 'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
+ 'negative_prompt': '(worst quality, low quality:1.3)',
+ 'side_length': 512,
+ 'restore_faces': False
+}
+
+SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
+
+streaming_state = shared.args.no_stream # remember if chat streaming was enabled
+picture_response = False # specifies if the next model response should appear as a picture
+pic_id = 0
+
+def remove_surrounded_chars(string):
+ new_string = ""
+ in_star = False
+ for char in string:
+ if char == '*':
+ in_star = not in_star
+ elif not in_star:
+ new_string += char
+ return new_string
+
+# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
+def input_modifier(string):
+ """
+ This function is applied to your text inputs before
+ they are fed into the model.
+ """
+ global params, picture_response
+ if not params['enable_SD_api']:
+ return string
+
+ commands = ['send', 'mail', 'me']
+ mediums = ['image', 'pic', 'picture', 'photo']
+ subjects = ['yourself', 'own']
+ lowstr = string.lower()
+ if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
+ picture_response = True
+ shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
+ shared.processing_message = "*Is sending a picture...*"
+ string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
+ if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
+ string = "Please provide a detailed and vivid description of how you look and what you are wearing"
+
+ return string
+
+# Get and save the Stable Diffusion-generated picture
+def get_SD_pictures(description):
+
+ global params, pic_id
+
+ payload = {
+ "prompt": params['prompt_prefix'] + description,
+ "seed": -1,
+ "sampler_name": "DPM++ 2M Karras",
+ "steps": 32,
+ "cfg_scale": 7,
+ "width": params['side_length'],
+ "height": params['side_length'],
+ "restore_faces": params['restore_faces'],
+ "negative_prompt": params['negative_prompt']
+ }
+
+ response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
+ r = response.json()
+
+ visible_result = ""
+ for img_str in r['images']:
+ image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
+ if params['save_img']:
+ output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
+ image.save(output_file.as_posix())
+ pic_id += 1
+ # lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
+ newsize = (300, 300)
+ image = image.resize(newsize, Image.LANCZOS)
+ buffered = io.BytesIO()
+ image.save(buffered, format="JPEG")
+ buffered.seek(0)
+ image_bytes = buffered.getvalue()
+ img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
+ visible_result = visible_result + f'\n'
+
+ return visible_result
+
+# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
+# and replace it with 'text' for the purposes of logging?
+def output_modifier(string):
+ """
+ This function is applied to the model outputs.
+ """
+ global pic_id, picture_response, streaming_state
+
+ if not picture_response:
+ return string
+
+ string = remove_surrounded_chars(string)
+ string = string.replace('"', '')
+ string = string.replace('“', '')
+ string = string.replace('\n', ' ')
+ string = string.strip()
+
+ if string == '':
+ string = 'no viable description in reply, try regenerating'
+
+ # I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
+ text = f'*Description: "{string}"*'
+
+ image = get_SD_pictures(string)
+
+ picture_response = False
+
+ shared.processing_message = "*Is typing...*"
+ shared.args.no_stream = streaming_state
+ return image + "\n" + text
+
+def bot_prefix_modifier(string):
+ """
+ This function is only applied in chat mode. It modifies
+ the prefix text for the Bot and can be used to bias its
+ behavior.
+ """
+
+ return string
+
+def force_pic():
+ global picture_response
+ picture_response = True
+
+def ui():
+
+ # Gradio elements
+ with gr.Accordion("Stable Diffusion api integration", open=True):
+ with gr.Row():
+ with gr.Column():
+ enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
+ save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
+ with gr.Column():
+ address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
+
+ with gr.Row():
+ force_btn = gr.Button("Force the next response to be a picture")
+ generate_now_btn = gr.Button("Generate an image response to the input")
+
+ with gr.Accordion("Generation parameters", open=False):
+ prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
+ with gr.Row():
+ negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
+ dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
+ # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
+
+ # Event functions to update the parameters in the backend
+ enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
+ save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
+ address.change(lambda x: params.update({"address": x}), address, None)
+ prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
+ negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
+ dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
+ # model.change(lambda x: params.update({"SD_model": x}), model, None)
+
+ force_btn.click(force_pic)
+ generate_now_btn.click(force_pic)
+ generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
From c79fc69e95fbe2aac85ccd414b92b7b3a425bd81 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 10:36:57 -0300
Subject: [PATCH 02/46] Fix the API example with streaming #417
---
api-example-stream.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/api-example-stream.py b/api-example-stream.py
index add1df41..055d605b 100644
--- a/api-example-stream.py
+++ b/api-example-stream.py
@@ -44,14 +44,14 @@ async def run(context):
case "send_hash":
await websocket.send(json.dumps({
"session_hash": session,
- "fn_index": 7
+ "fn_index": 9
}))
case "estimation":
pass
case "send_data":
await websocket.send(json.dumps({
"session_hash": session,
- "fn_index": 7,
+ "fn_index": 9,
"data": [
context,
params['max_new_tokens'],
From 7073e960933e8c2aad1881d57da1b42303cca528 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 12:05:28 -0300
Subject: [PATCH 03/46] Add back RWKV dependency #98
---
requirements.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/requirements.txt b/requirements.txt
index b3a17ea4..e5b3de69 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ markdown
numpy
peft==0.2.0
requests
+rwkv==0.7.0
safetensors==0.3.0
sentencepiece
tqdm
From a78b6508fcc0f5b597365e7ff0fa1a9f9e43d8ad Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 12:11:35 -0300
Subject: [PATCH 04/46] Make custom LoRAs work by default #385
---
modules/LoRA.py | 2 +-
modules/shared.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/LoRA.py b/modules/LoRA.py
index f29523d2..6915e157 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -17,6 +17,6 @@ def add_lora_to_model(lora_name):
print(f"Adding the LoRA {lora_name} to the model...")
params = {}
- #params['device_map'] = {'': 0}
+ params['device_map'] = {'': 0}
#params['dtype'] = shared.model.dtype
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
diff --git a/modules/shared.py b/modules/shared.py
index 2592ace7..e3920f22 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -56,7 +56,7 @@ settings = {
},
'lora_prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
- 'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
+ '(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
}
}
From 257edf5f56ebe9765135509e3cf4833207c34138 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 12:30:51 -0300
Subject: [PATCH 05/46] Make the Default preset more reasonable
Credits: anonymous 4chan user who got it off
"some twitter post or something someone linked,
who even knows anymore"
---
presets/Default.txt | 15 +++++----------
1 file changed, 5 insertions(+), 10 deletions(-)
diff --git a/presets/Default.txt b/presets/Default.txt
index 9f0983ec..d5283836 100644
--- a/presets/Default.txt
+++ b/presets/Default.txt
@@ -1,12 +1,7 @@
do_sample=True
-temperature=1
-top_p=1
-typical_p=1
-repetition_penalty=1
-top_k=50
-num_beams=1
-penalty_alpha=0
-min_length=0
-length_penalty=1
-no_repeat_ngram_size=0
+top_p=0.5
+top_k=40
+temperature=0.7
+repetition_penalty=1.2
+typical_p=1.0
early_stopping=False
From 4d701a6eb902919f35da40240d74a079d7a53df6 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 12:51:47 -0300
Subject: [PATCH 06/46] Create a mirror for the preset menu
---
presets/Individual Today.txt | 6 ------
server.py | 9 +++++++--
2 files changed, 7 insertions(+), 8 deletions(-)
delete mode 100644 presets/Individual Today.txt
diff --git a/presets/Individual Today.txt b/presets/Individual Today.txt
deleted file mode 100644
index f40b879c..00000000
--- a/presets/Individual Today.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-do_sample=True
-top_p=0.9
-top_k=50
-temperature=1.39
-repetition_penalty=1.08
-typical_p=0.2
diff --git a/server.py b/server.py
index 1d324fba..060f09d5 100644
--- a/server.py
+++ b/server.py
@@ -102,7 +102,7 @@ def load_preset_values(preset_menu, return_dict=False):
if return_dict:
return generate_params
else:
- return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+ return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -130,6 +130,10 @@ def create_model_and_preset_menus():
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
+ with gr.Row():
+ shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+ ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
+
with gr.Row():
with gr.Column():
with gr.Box():
@@ -174,7 +178,8 @@ def create_settings_menus(default_preset):
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
- shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+ shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+ shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
From ddb62470e94ec78d70d9e138acfdb543b84dd331 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 19:21:41 -0300
Subject: [PATCH 07/46] --no-cache and --gpu-memory in MiB for fine VRAM
control
---
README.md | 3 ++-
modules/models.py | 8 +++++---
modules/shared.py | 5 +++--
modules/text_generation.py | 4 +++-
4 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/README.md b/README.md
index ded9b351..330f36b1 100644
--- a/README.md
+++ b/README.md
@@ -183,7 +183,8 @@ Optionally, you can use the following command-line flags:
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
-| `--flexgen` | Enable the use of FlexGen offloading. |
+| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+| `--flexgen` | Enable the use of FlexGen offloading. |
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
diff --git a/modules/models.py b/modules/models.py
index f07e738b..ccb97da3 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -1,5 +1,6 @@
import json
import os
+import re
import time
import zipfile
from pathlib import Path
@@ -120,11 +121,12 @@ def load_model(model_name):
params["torch_dtype"] = torch.float16
if shared.args.gpu_memory:
- memory_map = shared.args.gpu_memory
+ memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
+ max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
max_memory = {}
for i in range(len(memory_map)):
- max_memory[i] = f'{memory_map[i]}GiB'
- max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
+ max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
+ max_memory['cpu'] = max_cpu_memory
params['max_memory'] = max_memory
elif shared.args.auto_devices:
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
diff --git a/modules/shared.py b/modules/shared.py
index e3920f22..8cae1079 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -85,8 +85,9 @@ parser.add_argument('--bf16', action='store_true', help='Load the model with bfl
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
-parser.add_argument('--gpu-memory', type=int, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs.')
-parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
+parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs.')
+parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
+parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 1d11de12..9159975c 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -136,7 +136,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
- generate_params = {}
+ generate_params = {
+ 'use_cache': not shared.args.no_cache,
+ }
if not shared.args.flexgen:
generate_params.update({
"max_new_tokens": max_new_tokens,
From b552d2b58a01cc18caf1664d7915940a1039de03 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 19:24:41 -0300
Subject: [PATCH 08/46] Remove unused imports o
---
extensions/sd_api_pictures/script.py | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index b6854a02..b9fba2b9 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -1,14 +1,11 @@
-import json
import base64
-import requests
import io
-
-from io import BytesIO
-from PIL import Image, PngImagePlugin
from pathlib import Path
import gradio as gr
+import requests
import torch
+from PIL import Image
import modules.chat as chat
import modules.shared as shared
@@ -183,4 +180,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
- generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
+ generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
From 7ddf6147accfb5b95e7dbbd7f1822cf976054a2a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 19:25:52 -0300
Subject: [PATCH 09/46] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 330f36b1..ce742c9e 100644
--- a/README.md
+++ b/README.md
@@ -202,7 +202,7 @@ Optionally, you can use the following command-line flags:
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--verbose` | Print the prompts to the terminal. |
-Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
+Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
## Presets
From 9378754cc7556e01ab7c7e54d512bc7fedef33fb Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 20:14:50 -0300
Subject: [PATCH 10/46] Update README
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index ce742c9e..44c6516e 100644
--- a/README.md
+++ b/README.md
@@ -181,9 +181,9 @@ Optionally, you can use the following command-line flags:
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
-| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. |
+| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
-| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit at a performance cost. |
| `--flexgen` | Enable the use of FlexGen offloading. |
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
From dd4374edde7b424dbbb0598d4cfafc931bcb63b2 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 20:15:15 -0300
Subject: [PATCH 11/46] Update README
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 44c6516e..f6d86caf 100644
--- a/README.md
+++ b/README.md
@@ -183,7 +183,7 @@ Optionally, you can use the following command-line flags:
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
-| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit at a performance cost. |
+| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
| `--flexgen` | Enable the use of FlexGen offloading. |
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
From 164e05daad924801a3b6b8a1fd55c5ed7a188919 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 20:34:52 -0300
Subject: [PATCH 12/46] Download .py files using download-model.py
---
download-model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/download-model.py b/download-model.py
index 808b9fc2..7c2965f6 100644
--- a/download-model.py
+++ b/download-model.py
@@ -117,7 +117,7 @@ def get_download_links_from_huggingface(model, branch):
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
- is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
+ is_text = re.match(".*\.(txt|json|py)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
if is_text:
From 31ab2be8ef1ba29571de6a19cf0af87b01f42702 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Mar 2023 22:10:55 -0300
Subject: [PATCH 13/46] Remove redundant requirements #309
---
extensions/sd_api_pictures/requirements.txt | 5 -----
1 file changed, 5 deletions(-)
delete mode 100644 extensions/sd_api_pictures/requirements.txt
diff --git a/extensions/sd_api_pictures/requirements.txt b/extensions/sd_api_pictures/requirements.txt
deleted file mode 100644
index 5f94b3eb..00000000
--- a/extensions/sd_api_pictures/requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-gradio
-modules
-Pillow
-requests
-torch
\ No newline at end of file
From a90f507abe0243949d2720a009283cc84d0e62ae Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 11:49:42 -0300
Subject: [PATCH 14/46] Exit elevenlabs_tts if streaming is enabled
---
extensions/elevenlabs_tts/script.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
index b8171063..6aa687ab 100644
--- a/extensions/elevenlabs_tts/script.py
+++ b/extensions/elevenlabs_tts/script.py
@@ -15,7 +15,10 @@ wav_idx = 0
user = ElevenLabsUser(params['api_key'])
user_info = None
-
+if not shared.args.no_stream:
+ print("Please add --no-stream. This extension is not meant to be used with streaming.")
+ raise ValueError
+
# Check if the API is valid and refresh the UI accordingly.
def check_valid_api():
From 75a7a84ef278cf24c5b59071f38c75ea5ab55aa4 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 13:36:52 -0300
Subject: [PATCH 15/46] Exception handling (#454)
* Update text_generation.py
* Update extensions.py
---
modules/extensions.py | 3 +++
modules/text_generation.py | 5 +++++
2 files changed, 8 insertions(+)
diff --git a/modules/extensions.py b/modules/extensions.py
index 836fbc60..dbc93840 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,3 +1,5 @@
+import traceback
+
import gradio as gr
import extensions
@@ -17,6 +19,7 @@ def load_extensions():
print('Ok.')
except:
print('Fail.')
+ traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line
def iterator():
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 9159975c..a70d490c 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -1,6 +1,7 @@
import gc
import re
import time
+import traceback
import numpy as np
import torch
@@ -110,6 +111,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
yield formatted_outputs(reply, shared.model_name)
+ except:
+ traceback.print_exc()
finally:
t1 = time.time()
output = encode(reply)[0]
@@ -243,6 +246,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
yield formatted_outputs(reply, shared.model_name)
+ except:
+ traceback.print_exc()
finally:
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
From 536d0a4d93861fc40269fa6053c0893b7b501528 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 14:00:40 -0300
Subject: [PATCH 16/46] Add an import
---
extensions/elevenlabs_tts/script.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
index 6aa687ab..7339cc73 100644
--- a/extensions/elevenlabs_tts/script.py
+++ b/extensions/elevenlabs_tts/script.py
@@ -4,6 +4,8 @@ import gradio as gr
from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path
+import modules.shared as shared
+
params = {
'activate': True,
'api_key': '12345',
From 9a3bed50c3f51c505b7ea57433c8018c7375d535 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 15:11:56 -0300
Subject: [PATCH 17/46] Attempt at fixing 4-bit with CPU offload
---
modules/GPTQ_loader.py | 15 ++++++++++-----
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 662182e7..7045a098 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -1,3 +1,4 @@
+import re
import sys
from pathlib import Path
@@ -56,16 +57,20 @@ def load_quantized(model_name):
# Multiple GPUs or GPU+CPU
if shared.args.gpu_memory:
+ memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
+ max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
max_memory = {}
- for i in range(len(shared.args.gpu_memory)):
- max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
- max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
+ for i in range(len(memory_map)):
+ max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
+ max_memory['cpu'] = max_cpu_memory
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
- model = accelerate.dispatch_model(model, device_map=device_map)
+ print("Using the following device map for the 4-bit model:", device_map)
+ # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
+ model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
# Single GPU
- else:
+ elif not shared.args.cpu:
model = model.to(torch.device('cuda:0'))
return model
From 7618f3fe8c0d1fad6fdc6f7d99f0346b74c8e535 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 16:30:56 -0300
Subject: [PATCH 18/46] Add -gptq-preload for 4-bit offloading (#460)
This works in a 4GB card now:
```
python server.py --model llama-7b-hf --gptq-bits 4 --gptq-pre-layer 20
```
---
modules/GPTQ_loader.py | 42 ++++++++++++++++++++++++------------------
modules/shared.py | 5 +++--
2 files changed, 27 insertions(+), 20 deletions(-)
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 7045a098..67899547 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -9,6 +9,7 @@ import modules.shared as shared
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama
+import llama_inference_offload
import opt
@@ -24,7 +25,10 @@ def load_quantized(model_name):
model_type = shared.args.gptq_model_type.lower()
if model_type == 'llama':
- load_quant = llama.load_quant
+ if not shared.args.gptq_pre_layer:
+ load_quant = llama.load_quant
+ else:
+ load_quant = llama_inference_offload.load_quant
elif model_type == 'opt':
load_quant = opt.load_quant
else:
@@ -53,24 +57,26 @@ def load_quantized(model_name):
print(f"Could not find {pt_model}, exiting...")
exit()
- model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
+ # Using qwopqwop200's offload
+ if shared.args.gptq_pre_layer:
+ model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits, shared.args.gptq_pre_layer)
+ else:
+ model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
- # Multiple GPUs or GPU+CPU
- if shared.args.gpu_memory:
- memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
- max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
- max_memory = {}
- for i in range(len(memory_map)):
- max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
- max_memory['cpu'] = max_cpu_memory
+ # Using accelerate offload (doesn't work properly)
+ if shared.args.gpu_memory:
+ memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
+ max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
+ max_memory = {}
+ for i in range(len(memory_map)):
+ max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
+ max_memory['cpu'] = max_cpu_memory
- device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
- print("Using the following device map for the 4-bit model:", device_map)
- # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
- model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
-
- # Single GPU
- elif not shared.args.cpu:
- model = model.to(torch.device('cuda:0'))
+ device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
+ print("Using the following device map for the 4-bit model:", device_map)
+ # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
+ model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
+ elif not shared.args.cpu:
+ model = model.to(torch.device('cuda:0'))
return model
diff --git a/modules/shared.py b/modules/shared.py
index 8cae1079..8d591f4f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -79,8 +79,9 @@ parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI i
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.')
-parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA and OPT.')
-parser.add_argument('--gptq-model-type', type=str, help='Model type of pre-quantized model. Currently only LLaMa and OPT are supported.')
+parser.add_argument('--gptq-bits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA and OPT.')
+parser.add_argument('--gptq-model-type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMa and OPT are supported.')
+parser.add_argument('--gptq-pre-layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
From db4219a340540d161f8864be914dc1bdade97e6d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 16:40:08 -0300
Subject: [PATCH 19/46] Update comments
---
modules/GPTQ_loader.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 67899547..32a5458f 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -57,13 +57,13 @@ def load_quantized(model_name):
print(f"Could not find {pt_model}, exiting...")
exit()
- # Using qwopqwop200's offload
+ # qwopqwop200's offload
if shared.args.gptq_pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits, shared.args.gptq_pre_layer)
else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
- # Using accelerate offload (doesn't work properly)
+ # accelerate offload (doesn't work properly)
if shared.args.gpu_memory:
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
@@ -76,6 +76,8 @@ def load_quantized(model_name):
print("Using the following device map for the 4-bit model:", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
+
+ # No offload
elif not shared.args.cpu:
model = model.to(torch.device('cuda:0'))
From 6872ffd9769e45f46e9f57ee7de704f025b811c0 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 16:53:14 -0300
Subject: [PATCH 20/46] Update README.md
---
README.md | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index f6d86caf..cb070445 100644
--- a/README.md
+++ b/README.md
@@ -175,8 +175,9 @@ Optionally, you can use the following command-line flags:
| `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
-| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
-| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
+| `--gptq-bits GPTQ_BITS` | GPTQ: Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
+| `--gptq-model-type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
+| `--gptq-pre-layer GPTQ_PRE_LAYER` | GPTQ: The number of layers to preload. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
From 45b7e53565ee0151cdbe2aa5760cfdf05f696d5c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 20:36:02 -0300
Subject: [PATCH 21/46] Only catch proper Exceptions in the text generation
function
---
modules/text_generation.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index a70d490c..84752b39 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -111,7 +111,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
yield formatted_outputs(reply, shared.model_name)
- except:
+ except Exception:
traceback.print_exc()
finally:
t1 = time.time()
@@ -246,7 +246,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
yield formatted_outputs(reply, shared.model_name)
- except:
+ except Exception:
traceback.print_exc()
finally:
t1 = time.time()
From 5389fce8e11a6018c44cfea4b29a1a0216ab5687 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com>
Date: Wed, 22 Mar 2023 07:47:54 +0300
Subject: [PATCH 22/46] Extensions performance & memory optimisations
Reworked remove_surrounded_chars() to use regular expression ( https://regexr.com/7alb5 ) instead of repeated string concatenations for elevenlab_tts, silero_tts, sd_api_pictures. This should be both faster and more robust in handling asterisks.
Reduced the memory footprint of send_pictures and sd_api_pictures by scaling the images in the chat to 300 pixels max-side wise. (The user already has the original in case of the sent picture and there's an option to save the SD generation).
This should fix history growing annoyingly large with multiple pictures present
---
extensions/elevenlabs_tts/script.py | 14 ++++++--------
extensions/sd_api_pictures/script.py | 26 ++++++++++++++++----------
extensions/send_pictures/script.py | 13 ++++++++++++-
extensions/silero_tts/script.py | 13 +++++--------
4 files changed, 39 insertions(+), 27 deletions(-)
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
index 7339cc73..cee64c06 100644
--- a/extensions/elevenlabs_tts/script.py
+++ b/extensions/elevenlabs_tts/script.py
@@ -4,6 +4,8 @@ import gradio as gr
from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path
+import re
+
import modules.shared as shared
params = {
@@ -52,14 +54,10 @@ def refresh_voices():
return
def remove_surrounded_chars(string):
- new_string = ""
- in_star = False
- for char in string:
- if char == '*':
- in_star = not in_star
- elif not in_star:
- new_string += char
- return new_string
+ # regexp is way faster than repeated string concatenation!
+ # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+ # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+ return re.sub('\*[^\*]*?(\*|$)','',string)
def input_modifier(string):
"""
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index b9fba2b9..03d3c784 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -1,5 +1,6 @@
import base64
import io
+import re
from pathlib import Path
import gradio as gr
@@ -31,14 +32,10 @@ picture_response = False # specifies if the next model response should appear as
pic_id = 0
def remove_surrounded_chars(string):
- new_string = ""
- in_star = False
- for char in string:
- if char == '*':
- in_star = not in_star
- elif not in_star:
- new_string += char
- return new_string
+ # regexp is way faster than repeated string concatenation!
+ # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+ # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+ return re.sub('\*[^\*]*?(\*|$)','',string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string):
@@ -54,6 +51,8 @@ def input_modifier(string):
mediums = ['image', 'pic', 'picture', 'photo']
subjects = ['yourself', 'own']
lowstr = string.lower()
+
+ # TODO: refactor out to separate handler and also replace detection with a regexp
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
picture_response = True
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
@@ -91,8 +90,15 @@ def get_SD_pictures(description):
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
image.save(output_file.as_posix())
pic_id += 1
- # lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
- newsize = (300, 300)
+ # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
+ width, height = image.size
+ if (width > 300):
+ height = int(height * (300 / width))
+ width = 300
+ elif (height > 300):
+ width = int(width * (300 / height))
+ height = 300
+ newsize = (width, height)
image = image.resize(newsize, Image.LANCZOS)
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index b0c35632..05ceed8b 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -4,6 +4,7 @@ from io import BytesIO
import gradio as gr
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
+from PIL import Image
import modules.chat as chat
import modules.shared as shared
@@ -25,10 +26,20 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
+ # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
+ width, height = picture.size
+ if (width > 300):
+ height = int(height * (300 / width))
+ width = 300
+ elif (height > 300):
+ width = int(width * (300 / height))
+ height = 300
+ newsize = (width, height)
+ picture = picture.resize(newsize, Image.LANCZOS)
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
- visible_text = f''
+ visible_text = f''
return text, visible_text
def ui():
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py
index f611dc27..447c7033 100644
--- a/extensions/silero_tts/script.py
+++ b/extensions/silero_tts/script.py
@@ -3,6 +3,7 @@ from pathlib import Path
import gradio as gr
import torch
+import re
import modules.chat as chat
import modules.shared as shared
@@ -46,14 +47,10 @@ def load_model():
model = load_model()
def remove_surrounded_chars(string):
- new_string = ""
- in_star = False
- for char in string:
- if char == '*':
- in_star = not in_star
- elif not in_star:
- new_string += char
- return new_string
+ # regexp is way faster than repeated string concatenation!
+ # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+ # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+ return re.sub('\*[^\*]*?(\*|$)','',string)
def remove_tts_from_history(name1, name2):
for i, entry in enumerate(shared.history['internal']):
From 61346b88eac3e1664c81de842168ea8570c26707 Mon Sep 17 00:00:00 2001
From: wywywywy
Date: Wed, 22 Mar 2023 18:40:20 +0000
Subject: [PATCH 23/46] Add "seed" menu in the Parameters tab
---
api-example-stream.py | 6 ++++--
api-example.py | 2 ++
extensions/api/script.py | 3 ++-
modules/chat.py | 16 ++++++++--------
modules/text_generation.py | 9 ++++++++-
server.py | 16 +++++++++-------
6 files changed, 33 insertions(+), 19 deletions(-)
diff --git a/api-example-stream.py b/api-example-stream.py
index 055d605b..e87fb74c 100644
--- a/api-example-stream.py
+++ b/api-example-stream.py
@@ -34,6 +34,7 @@ async def run(context):
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
+ 'seed': -1,
}
session = random_hash()
@@ -44,14 +45,14 @@ async def run(context):
case "send_hash":
await websocket.send(json.dumps({
"session_hash": session,
- "fn_index": 9
+ "fn_index": 12
}))
case "estimation":
pass
case "send_data":
await websocket.send(json.dumps({
"session_hash": session,
- "fn_index": 9,
+ "fn_index": 12,
"data": [
context,
params['max_new_tokens'],
@@ -68,6 +69,7 @@ async def run(context):
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
+ params['seed'],
]
}))
case "process_starts":
diff --git a/api-example.py b/api-example.py
index a6f0c10e..0349824b 100644
--- a/api-example.py
+++ b/api-example.py
@@ -32,6 +32,7 @@ params = {
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
+ 'seed': -1,
}
# Input prompt
@@ -54,6 +55,7 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
+ params['seed'],
]
}).json()
diff --git a/extensions/api/script.py b/extensions/api/script.py
index 53e47f3f..1774c345 100644
--- a/extensions/api/script.py
+++ b/extensions/api/script.py
@@ -56,6 +56,7 @@ class Handler(BaseHTTPRequestHandler):
penalty_alpha=0,
length_penalty=1,
early_stopping=False,
+ seed=-1,
)
answer = ''
@@ -87,4 +88,4 @@ def run_server():
server.serve_forever()
def ui():
- Thread(target=run_server, daemon=True).start()
\ No newline at end of file
+ Thread(target=run_server, daemon=True).start()
diff --git a/modules/chat.py b/modules/chat.py
index 36265990..78fc4ab5 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -91,7 +91,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
def stop_everything_event():
shared.stop_everything = True
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
@@ -127,7 +127,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
# Generate
reply = ''
for i in range(chat_generation_attempts):
- for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+ for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_string=f"\n{name1}:"):
# Extracting the reply
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
@@ -154,7 +154,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
yield shared.history['visible']
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
eos_token = '\n' if check else None
if 'pygmalion' in shared.model_name.lower():
@@ -166,18 +166,18 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
# Yield *Is typing...*
yield shared.processing_message
for i in range(chat_generation_attempts):
- for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
+ for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
yield reply
if next_character_found:
break
yield reply
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
- for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+ for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
yield generate_chat_html(_history, name1, name2, shared.character)
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
else:
@@ -185,7 +185,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*'
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
- for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
+ for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
else:
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 84752b39..610bd4fc 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -93,8 +93,15 @@ def clear_torch_cache():
if not shared.args.cpu:
torch.cuda.empty_cache()
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+def set_manual_seed(seed):
+ if seed != -1:
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_string=None):
clear_torch_cache()
+ set_manual_seed(seed)
t0 = time.time()
# These models are not part of Hugging Face, so we handle them
diff --git a/server.py b/server.py
index 060f09d5..cdf7aa93 100644
--- a/server.py
+++ b/server.py
@@ -130,10 +130,6 @@ def create_model_and_preset_menus():
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
- with gr.Row():
- shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
- ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
-
with gr.Row():
with gr.Column():
with gr.Box():
@@ -164,6 +160,12 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+ shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
+
+ with gr.Row():
+ shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+ ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
+
with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
@@ -330,7 +332,7 @@ def create_interface():
create_settings_menus(default_preset)
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
- shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
@@ -393,7 +395,7 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
- shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
@@ -425,7 +427,7 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
- shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
From 104212529ff8c6b41baf12bfd4673fe782ed04d7 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 22 Mar 2023 15:55:03 -0300
Subject: [PATCH 24/46] Minor changes
---
extensions/elevenlabs_tts/script.py | 9 +++------
extensions/sd_api_pictures/script.py | 8 +++-----
extensions/send_pictures/script.py | 9 ++++-----
extensions/silero_tts/script.py | 8 +++-----
4 files changed, 13 insertions(+), 21 deletions(-)
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
index cee64c06..2e8b184f 100644
--- a/extensions/elevenlabs_tts/script.py
+++ b/extensions/elevenlabs_tts/script.py
@@ -1,13 +1,11 @@
+import re
from pathlib import Path
import gradio as gr
+import modules.shared as shared
from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path
-import re
-
-import modules.shared as shared
-
params = {
'activate': True,
'api_key': '12345',
@@ -54,7 +52,6 @@ def refresh_voices():
return
def remove_surrounded_chars(string):
- # regexp is way faster than repeated string concatenation!
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
@@ -113,4 +110,4 @@ def ui():
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
connect.click(check_valid_api, [], connection_status)
- connect.click(refresh_voices, [], voice)
+ connect.click(refresh_voices, [], voice)
\ No newline at end of file
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index 03d3c784..1f6ba2d2 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -4,13 +4,12 @@ import re
from pathlib import Path
import gradio as gr
+import modules.chat as chat
+import modules.shared as shared
import requests
import torch
from PIL import Image
-import modules.chat as chat
-import modules.shared as shared
-
torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui
@@ -32,7 +31,6 @@ picture_response = False # specifies if the next model response should appear as
pic_id = 0
def remove_surrounded_chars(string):
- # regexp is way faster than repeated string concatenation!
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
@@ -186,4 +184,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
- generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+ generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 05ceed8b..46393e6c 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -2,12 +2,11 @@ import base64
from io import BytesIO
import gradio as gr
-import torch
-from transformers import BlipForConditionalGeneration, BlipProcessor
-from PIL import Image
-
import modules.chat as chat
import modules.shared as shared
+import torch
+from PIL import Image
+from transformers import BlipForConditionalGeneration, BlipProcessor
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
@@ -54,4 +53,4 @@ def ui():
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field
- picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
+ picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
\ No newline at end of file
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py
index 447c7033..a81a5da1 100644
--- a/extensions/silero_tts/script.py
+++ b/extensions/silero_tts/script.py
@@ -1,12 +1,11 @@
+import re
import time
from pathlib import Path
import gradio as gr
-import torch
-import re
-
import modules.chat as chat
import modules.shared as shared
+import torch
torch._C._jit_set_profiling_mode(False)
@@ -47,7 +46,6 @@ def load_model():
model = load_model()
def remove_surrounded_chars(string):
- # regexp is way faster than repeated string concatenation!
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
@@ -163,4 +161,4 @@ def ui():
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
- v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
+ v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
\ No newline at end of file
From 0abff499e2e40fb308b406b82ad3f850683720e3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 22 Mar 2023 16:03:05 -0300
Subject: [PATCH 25/46] Use image.thumbnail
---
extensions/sd_api_pictures/script.py | 10 +---------
extensions/send_pictures/script.py | 10 +---------
2 files changed, 2 insertions(+), 18 deletions(-)
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index 1f6ba2d2..cc85f3b3 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -89,15 +89,7 @@ def get_SD_pictures(description):
image.save(output_file.as_posix())
pic_id += 1
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
- width, height = image.size
- if (width > 300):
- height = int(height * (300 / width))
- width = 300
- elif (height > 300):
- width = int(width * (300 / height))
- height = 300
- newsize = (width, height)
- image = image.resize(newsize, Image.LANCZOS)
+ image.thumbnail((300, 300))
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
buffered.seek(0)
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 46393e6c..196c7d53 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -26,15 +26,7 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
- width, height = picture.size
- if (width > 300):
- height = int(height * (300 / width))
- width = 300
- elif (height > 300):
- width = int(width * (300 / height))
- height = 300
- newsize = (width, height)
- picture = picture.resize(newsize, Image.LANCZOS)
+ image.thumbnail((300, 300))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
From bfb1be2820417fb5d2e843d9e164862b8962446d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 22 Mar 2023 16:09:48 -0300
Subject: [PATCH 26/46] Minor fix
---
extensions/send_pictures/script.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 196c7d53..556a88e5 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -26,7 +26,7 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
- image.thumbnail((300, 300))
+ picture.thumbnail((300, 300))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
@@ -45,4 +45,4 @@ def ui():
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field
- picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
\ No newline at end of file
+ picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
From de6a09dc7f7d5a5d8496cfa1598abb4ff5ee1338 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 00:12:40 -0300
Subject: [PATCH 27/46] Properly separate the original prompt from the reply
---
modules/text_generation.py | 30 +++++++++++++++++++-----------
1 file changed, 19 insertions(+), 11 deletions(-)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 610bd4fc..d539f6d4 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -136,6 +136,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
input_ids = encode(question, max_new_tokens)
original_input_ids = input_ids
output = input_ids[0]
+
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
if eos_token is not None:
@@ -146,9 +147,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
- generate_params = {
- 'use_cache': not shared.args.no_cache,
- }
+ generate_params = {}
if not shared.args.flexgen:
generate_params.update({
"max_new_tokens": max_new_tokens,
@@ -175,6 +174,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
"temperature": temperature,
"stop": eos_token_ids[-1],
})
+ if shared.args.no_cache:
+ generate_params.update({"use_cache": False})
if shared.args.deepspeed:
generate_params.update({"synced_gpus": True})
if shared.soft_prompt:
@@ -194,9 +195,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
- reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
- reply = original_question + apply_extensions(reply[len(question):], "output")
+ new_tokens = len(output) - len(input_ids[0])
+ reply = decode(output[-new_tokens:])
+ reply = original_question + apply_extensions(reply, "output")
+ else:
+ reply = decode(output)
yield formatted_outputs(reply, shared.model_name)
@@ -219,10 +223,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
- reply = decode(output)
-
if not (shared.args.chat or shared.args.cai_chat):
- reply = original_question + apply_extensions(reply[len(question):], "output")
+ new_tokens = len(output) - len(input_ids[0])
+ reply = decode(output[-new_tokens:])
+ reply = original_question + apply_extensions(reply, "output")
+ else:
+ reply = decode(output)
if output[-1] in eos_token_ids:
break
@@ -238,10 +244,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
output = shared.model.generate(**generate_params)[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
- reply = decode(output)
-
if not (shared.args.chat or shared.args.cai_chat):
- reply = original_question + apply_extensions(reply[len(question):], "output")
+ new_tokens = len(output) - len(original_input_ids[0])
+ reply = decode(output[-new_tokens:])
+ reply = original_question + apply_extensions(reply, "output")
+ else:
+ reply = decode(output)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
From 7b6f85d3276480c90677bf55f33fa9f8fa7a2037 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 00:13:34 -0300
Subject: [PATCH 28/46] Fix markdown headers in light mode
---
css/main.css | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/css/main.css b/css/main.css
index c6b0b07e..09f3b6a8 100644
--- a/css/main.css
+++ b/css/main.css
@@ -50,3 +50,7 @@ ol li p, ul li p {
#main, #parameters, #chat-settings, #interface-mode, #lora {
border: 0;
}
+
+.gradio-container-3-18-0 .prose * h1, h2, h3, h4 {
+ color: white;
+}
From bfa81e105e2b9075d390fb23e4b894a8456444e5 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 00:22:14 -0300
Subject: [PATCH 29/46] Fix FlexGen streaming
---
modules/text_generation.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index d539f6d4..e738cb21 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -258,6 +258,10 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+ generate_params.update({"inputs_embeds": inputs_embeds})
+ generate_params.update({"inputs": filler_input_ids})
+ else:
+ generate_params.update({"inputs": input_ids})
yield formatted_outputs(reply, shared.model_name)
From eac27f4f556b2e4fd149e65e2395fbc9ce2ea3c7 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 00:55:33 -0300
Subject: [PATCH 30/46] Make LoRAs work in 16-bit mode
---
modules/LoRA.py | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 6915e157..20850338 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -13,10 +13,15 @@ def add_lora_to_model(lora_name):
print("Reloading the model to remove the LoRA...")
shared.model, shared.tokenizer = load_model(shared.model_name)
else:
- # Why doesn't this work in 16-bit mode?
print(f"Adding the LoRA {lora_name} to the model...")
-
+
params = {}
- params['device_map'] = {'': 0}
- #params['dtype'] = shared.model.dtype
+ if shared.args.load_in_8bit:
+ params['device_map'] = {'': 0}
+ else:
+ params['device_map'] = 'auto'
+ params['dtype'] = shared.model.dtype
+
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
+ if not shared.args.load_in_8bit:
+ shared.model.half()
From 29bd41d453cc8404b7183af685cdd4b952e96435 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 01:05:13 -0300
Subject: [PATCH 31/46] Fix LoRA in CPU mode
---
modules/LoRA.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 20850338..0a2aaa7d 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -18,10 +18,10 @@ def add_lora_to_model(lora_name):
params = {}
if shared.args.load_in_8bit:
params['device_map'] = {'': 0}
- else:
+ elif not shared.args.cpu:
params['device_map'] = 'auto'
params['dtype'] = shared.model.dtype
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
- if not shared.args.load_in_8bit:
+ if not shared.args.load_in_8bit and not shared.args.cpu:
shared.model.half()
From c5ebcc5f7e862b1f2c6b1d807bbf2c1aadeb159e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 13:36:00 -0300
Subject: [PATCH 32/46] Change the default names (#518)
* Update shared.py
* Update settings-template.json
---
modules/shared.py | 6 +++---
settings-template.json | 6 +++---
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 8d591f4f..720c697e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -27,9 +27,9 @@ settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
'max_new_tokens_max': 2000,
- 'name1': 'Person 1',
- 'name2': 'Person 2',
- 'context': 'This is a conversation between two people.',
+ 'name1': 'You',
+ 'name2': 'Assistant',
+ 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'stop_at_newline': False,
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
diff --git a/settings-template.json b/settings-template.json
index 7a7de7af..79fd5023 100644
--- a/settings-template.json
+++ b/settings-template.json
@@ -2,9 +2,9 @@
"max_new_tokens": 200,
"max_new_tokens_min": 1,
"max_new_tokens_max": 2000,
- "name1": "Person 1",
- "name2": "Person 2",
- "context": "This is a conversation between two people.",
+ "name1": "You",
+ "name2": "Assistant",
+ "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
"stop_at_newline": false,
"chat_prompt_size": 2048,
"chat_prompt_size_min": 0,
From 9bf6ecf9e2de9b72c3fa62e0e6f5b5e9041825b1 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 16:49:41 -0300
Subject: [PATCH 33/46] Fix LoRA device map (attempt)
---
modules/LoRA.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 0a2aaa7d..5f77e340 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -16,12 +16,15 @@ def add_lora_to_model(lora_name):
print(f"Adding the LoRA {lora_name} to the model...")
params = {}
- if shared.args.load_in_8bit:
- params['device_map'] = {'': 0}
- elif not shared.args.cpu:
- params['device_map'] = 'auto'
+ if not shared.args.cpu:
params['dtype'] = shared.model.dtype
+ if hasattr(shared.model, "hf_device_map"):
+ params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
+ elif shared.args.load_in_8bit:
+ params['device_map'] = {'': 0}
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
if not shared.args.load_in_8bit and not shared.args.cpu:
shared.model.half()
+ if not hasattr(shared.model, "hf_device_map"):
+ shared.model.cuda()
From 4578e88ffd77dc249fa97d0ec8cb667b21089ba8 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 21:38:20 -0300
Subject: [PATCH 34/46] Stop the bot from talking for you in chat mode
---
modules/RWKV.py | 4 ++--
modules/callbacks.py | 20 ++++++++---------
modules/chat.py | 44 ++++++++++++++-----------------------
modules/text_generation.py | 45 +++++++++++++++++++-------------------
4 files changed, 51 insertions(+), 62 deletions(-)
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 5cf8937a..8c7ea2b9 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -45,11 +45,11 @@ class RWKVModel:
token_stop = token_stop
)
- return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
+ return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
def generate_with_streaming(self, **kwargs):
with Iteratorize(self.generate, kwargs, callback=None) as generator:
- reply = kwargs['context']
+ reply = ''
for token in generator:
reply += token
yield reply
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 12a90cc3..2ae9d908 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -11,24 +11,22 @@ import modules.shared as shared
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
- def __init__(self, sentinel_token_ids: torch.LongTensor,
- starting_idx: int):
+ def __init__(self, sentinel_token_ids: list[torch.LongTensor], starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
- def __call__(self, input_ids: torch.LongTensor,
- _scores: torch.FloatTensor) -> bool:
+ def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
- # Can't unfold, output is still too tiny. Skip.
- if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
- continue
- for window in trimmed_sample.unfold(
- 0, self.sentinel_token_ids.shape[-1], 1):
- if torch.all(torch.eq(self.sentinel_token_ids, window)):
- return True
+ for i in range(len(self.sentinel_token_ids)):
+ # Can't unfold, output is still too tiny. Skip.
+ if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
+ continue
+ for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
+ if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
+ return True
return False
class Stream(transformers.StoppingCriteria):
diff --git a/modules/chat.py b/modules/chat.py
index 78fc4ab5..b1280d48 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -51,41 +51,31 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
prompt = ''.join(rows)
return prompt
-def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False):
+def extract_message_from_reply(reply, name1, name2, check):
next_character_found = False
- asker = name1 if not impersonate else name2
- replier = name2 if not impersonate else name1
-
- previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", question)]
- idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", reply)]
- idx = idx[max(len(previous_idx)-1, 0)]
-
- if not impersonate:
- reply = reply[idx + 1 + len(apply_extensions(f"{replier}:", "bot_prefix")):]
- else:
- reply = reply[idx + 1 + len(f"{replier}:"):]
-
if check:
lines = reply.split('\n')
reply = lines[0].strip()
if len(lines) > 1:
next_character_found = True
else:
- idx = reply.find(f"\n{asker}:")
- if idx != -1:
- reply = reply[:idx]
- next_character_found = True
- reply = fix_newlines(reply)
+ for string in [f"\n{name1}:", f"\n{name2}:"]:
+ idx = reply.find(string)
+ if idx != -1:
+ reply = reply[:idx]
+ next_character_found = True
# If something like "\nYo" is generated just before "\nYou:"
# is completed, trim it
- next_turn = f"\n{asker}:"
- for j in range(len(next_turn)-1, 0, -1):
- if reply[-j:] == next_turn[:j]:
- reply = reply[:-j]
- break
+ if not next_character_found:
+ for string in [f"\n{name1}:", f"\n{name2}:"]:
+ for j in range(len(string)-1, 0, -1):
+ if reply[-j:] == string[:j]:
+ reply = reply[:-j]
+ break
+ reply = fix_newlines(reply)
return reply, next_character_found
def stop_everything_event():
@@ -127,10 +117,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
# Generate
reply = ''
for i in range(chat_generation_attempts):
- for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+ for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
# Extracting the reply
- reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
+ reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
visible_reply = re.sub("(||{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output")
if shared.args.chat:
@@ -166,8 +156,8 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
# Yield *Is typing...*
yield shared.processing_message
for i in range(chat_generation_attempts):
- for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_string=f"\n{name2}:"):
- reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
+ for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+ reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
yield reply
if next_character_found:
break
diff --git a/modules/text_generation.py b/modules/text_generation.py
index e738cb21..fd017e2c 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -99,25 +99,37 @@ def set_manual_seed(seed):
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_string=None):
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
clear_torch_cache()
set_manual_seed(seed)
t0 = time.time()
+ original_question = question
+ if not (shared.args.chat or shared.args.cai_chat):
+ question = apply_extensions(question, "input")
+ if shared.args.verbose:
+ print(f"\n\n{question}\n--------------------\n")
+
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
if shared.is_RWKV:
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
else:
if not (shared.args.chat or shared.args.cai_chat):
yield formatted_outputs(question, shared.model_name)
+
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
+
except Exception:
traceback.print_exc()
finally:
@@ -127,12 +139,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
return
- original_question = question
- if not (shared.args.chat or shared.args.cai_chat):
- question = apply_extensions(question, "input")
- if shared.args.verbose:
- print(f"\n\n{question}\n--------------------\n")
-
input_ids = encode(question, max_new_tokens)
original_input_ids = input_ids
output = input_ids[0]
@@ -142,9 +148,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
stopping_criteria_list = transformers.StoppingCriteriaList()
- if stopping_string is not None:
- # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
- t = encode(stopping_string, 0, add_special_tokens=False)
+ if type(stopping_strings) is list and len(stopping_strings) > 0:
+ t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
generate_params = {}
@@ -195,12 +200,10 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+ new_tokens = len(output) - len(input_ids[0])
+ reply = decode(output[-new_tokens:])
if not (shared.args.chat or shared.args.cai_chat):
- new_tokens = len(output) - len(input_ids[0])
- reply = decode(output[-new_tokens:])
reply = original_question + apply_extensions(reply, "output")
- else:
- reply = decode(output)
yield formatted_outputs(reply, shared.model_name)
@@ -223,12 +226,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ new_tokens = len(output) - len(input_ids[0])
+ reply = decode(output[-new_tokens:])
if not (shared.args.chat or shared.args.cai_chat):
- new_tokens = len(output) - len(input_ids[0])
- reply = decode(output[-new_tokens:])
reply = original_question + apply_extensions(reply, "output")
- else:
- reply = decode(output)
if output[-1] in eos_token_ids:
break
@@ -244,12 +246,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
output = shared.model.generate(**generate_params)[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ new_tokens = len(output) - len(original_input_ids[0])
+ reply = decode(output[-new_tokens:])
if not (shared.args.chat or shared.args.cai_chat):
- new_tokens = len(output) - len(original_input_ids[0])
- reply = decode(output[-new_tokens:])
reply = original_question + apply_extensions(reply, "output")
- else:
- reply = decode(output)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
From bf22d16ebcee96430d6845c9786bbdab5e74af17 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 21:56:26 -0300
Subject: [PATCH 35/46] Clear cache while switching LoRAs
---
modules/LoRA.py | 15 +++++++++------
modules/callbacks.py | 8 +-------
server.py | 14 +++-----------
3 files changed, 13 insertions(+), 24 deletions(-)
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 5f77e340..1c03826b 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -2,19 +2,22 @@ from pathlib import Path
import modules.shared as shared
from modules.models import load_model
+from modules.text_generation import clear_torch_cache
+def reload_model():
+ shared.model = shared.tokenizer = None
+ clear_torch_cache()
+ shared.model, shared.tokenizer = load_model(shared.model_name)
+
def add_lora_to_model(lora_name):
from peft import PeftModel
- # Is there a more efficient way of returning to the base model?
- if lora_name == "None":
- print("Reloading the model to remove the LoRA...")
- shared.model, shared.tokenizer = load_model(shared.model_name)
- else:
+ reload_model()
+
+ if lora_name != "None":
print(f"Adding the LoRA {lora_name} to the model...")
-
params = {}
if not shared.args.cpu:
params['dtype'] = shared.model.dtype
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 2ae9d908..50a69183 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -1,11 +1,10 @@
-import gc
from queue import Queue
from threading import Thread
import torch
import transformers
-import modules.shared as shared
+from modules.text_generation import clear_torch_cache
# Copied from https://github.com/PygmalionAI/gradio-ui/
@@ -90,8 +89,3 @@ class Iteratorize:
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True
clear_torch_cache()
-
-def clear_torch_cache():
- gc.collect()
- if not shared.args.cpu:
- torch.cuda.empty_cache()
diff --git a/server.py b/server.py
index cdf7aa93..068f380a 100644
--- a/server.py
+++ b/server.py
@@ -1,4 +1,3 @@
-import gc
import io
import json
import re
@@ -8,7 +7,6 @@ import zipfile
from pathlib import Path
import gradio as gr
-import torch
import modules.chat as chat
import modules.extensions as extensions_module
@@ -17,7 +15,7 @@ import modules.ui as ui
from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
-from modules.text_generation import generate_reply
+from modules.text_generation import clear_torch_cache, generate_reply
# Loading custom settings
settings_file = None
@@ -56,21 +54,15 @@ def load_model_wrapper(selected_model):
if selected_model != shared.model_name:
shared.model_name = selected_model
shared.model = shared.tokenizer = None
- if not shared.args.cpu:
- gc.collect()
- torch.cuda.empty_cache()
+ clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name)
return selected_model
def load_lora_wrapper(selected_lora):
shared.lora_name = selected_lora
- default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
-
- if not shared.args.cpu:
- gc.collect()
- torch.cuda.empty_cache()
add_lora_to_model(selected_lora)
+ default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
return selected_lora, default_text
From b0abb327d822f8fe4c0180a4a725c0e362182b8f Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 22:02:09 -0300
Subject: [PATCH 36/46] Update LoRA.py
---
modules/LoRA.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 1c03826b..aa68ad32 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -14,7 +14,11 @@ def add_lora_to_model(lora_name):
from peft import PeftModel
- reload_model()
+ # If a LoRA had been previously loaded, or if we want
+ # to unload a LoRA, reload the model
+ if shared.lora_name != "None" or lora_name == "None":
+ reload_model()
+ shared.lora_name = lora_name
if lora_name != "None":
print(f"Adding the LoRA {lora_name} to the model...")
From 9bdb3c784d07b4f81f8dc39a97796d231bd89bff Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 22:02:40 -0300
Subject: [PATCH 37/46] Minor fix
---
server.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/server.py b/server.py
index 068f380a..435b8525 100644
--- a/server.py
+++ b/server.py
@@ -60,7 +60,6 @@ def load_model_wrapper(selected_model):
return selected_model
def load_lora_wrapper(selected_lora):
- shared.lora_name = selected_lora
add_lora_to_model(selected_lora)
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
From d1327f99f915aca83abac739107cdb8c5d29d278 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 22:12:24 -0300
Subject: [PATCH 38/46] Fix broken callbacks.py
---
modules/callbacks.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 50a69183..93cd1d63 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -4,8 +4,6 @@ from threading import Thread
import torch
import transformers
-from modules.text_generation import clear_torch_cache
-
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
@@ -89,3 +87,8 @@ class Iteratorize:
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True
clear_torch_cache()
+
+def clear_torch_cache():
+ gc.collect()
+ if not shared.args.cpu:
+ torch.cuda.empty_cache()
From 7078d168c31084255a99e1b4fd879e9a8a353a0d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 22:16:08 -0300
Subject: [PATCH 39/46] Missing import
---
modules/callbacks.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 93cd1d63..40811408 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -1,3 +1,4 @@
+import gc
from queue import Queue
from threading import Thread
From 8747c74339cf1e7f1d45f4aa1dcc090e9eba94a3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 22:19:01 -0300
Subject: [PATCH 40/46] Another missing import
---
modules/callbacks.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 40811408..2ae9d908 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -5,6 +5,8 @@ from threading import Thread
import torch
import transformers
+import modules.shared as shared
+
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
From 143b5b5edf5d47539496598dbdb6cfe4843c169a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 23:28:50 -0300
Subject: [PATCH 41/46] Mention one-click-bandaid in the README
---
README.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/README.md b/README.md
index cb070445..85dcc270 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,10 @@ Just download the zip above, extract it, and double click on "install". The web
Source codes: https://github.com/oobabooga/one-click-installers
+> **Note**
+>
+> To get 8-bit and 4-bit models working in your 1-click Windows installation, you can use the [one-click-bandaid](https://github.com/ClayShoaf/oobabooga-one-click-bandaid).
+
This method lags behind the newest developments and does not support 8-bit mode on Windows without additional set up: https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134, https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652
### Alternative: Docker
From bb4cb2245373acb950e1c8dbaa73caf75920723d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 24 Mar 2023 00:49:04 -0300
Subject: [PATCH 42/46] Download .pt files using download-model.py (for 4-bit
models)
---
download-model.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/download-model.py b/download-model.py
index 7c2965f6..7ca33b7d 100644
--- a/download-model.py
+++ b/download-model.py
@@ -116,10 +116,11 @@ def get_download_links_from_huggingface(model, branch):
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
+ is_pt = re.match(".*\.pt", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json|py)", fname) or is_tokenizer
- if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
+ if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text')
@@ -132,7 +133,8 @@ def get_download_links_from_huggingface(model, branch):
elif is_pytorch:
has_pytorch = True
classifications.append('pytorch')
-
+ elif is_pt:
+ classifications.append('pt')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
From 04417b658b53207c805851145c96bc1ce903937b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 24 Mar 2023 01:40:43 -0300
Subject: [PATCH 43/46] Update README.md
---
README.md | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/README.md b/README.md
index 85dcc270..4e4959ac 100644
--- a/README.md
+++ b/README.md
@@ -84,10 +84,6 @@ pip install -r requirements.txt
>
> For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859
-### Alternative: native Windows installation
-
-As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
-
### Alternative: one-click installers
[oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
@@ -105,7 +101,9 @@ Source codes: https://github.com/oobabooga/one-click-installers
>
> To get 8-bit and 4-bit models working in your 1-click Windows installation, you can use the [one-click-bandaid](https://github.com/ClayShoaf/oobabooga-one-click-bandaid).
-This method lags behind the newest developments and does not support 8-bit mode on Windows without additional set up: https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134, https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652
+### Alternative: native Windows installation
+
+As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
### Alternative: Docker
From 4f5c2ce78560689dc8ed08a3cbb33ef15a3b4a95 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 24 Mar 2023 02:03:30 -0300
Subject: [PATCH 44/46] Fix chat_generation_attempts
---
modules/chat.py | 18 +++++++++++++-----
1 file changed, 13 insertions(+), 5 deletions(-)
diff --git a/modules/chat.py b/modules/chat.py
index b1280d48..061177d2 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -115,9 +115,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
yield shared.history['visible']+[[visible_text, shared.processing_message]]
# Generate
- reply = ''
+ cumulative_reply = ''
for i in range(chat_generation_attempts):
- for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+ for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+ reply = cumulative_reply + reply
# Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
@@ -142,6 +143,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
if next_character_found:
break
+ cumulative_reply = reply
+
yield shared.history['visible']
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
@@ -152,16 +155,21 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
- reply = ''
# Yield *Is typing...*
yield shared.processing_message
+
+ cumulative_reply = ''
for i in range(chat_generation_attempts):
- for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+ for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+ reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
yield reply
if next_character_found:
break
- yield reply
+
+ cumulative_reply = reply
+
+ yield reply
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
From fd99995b01878246b62302d31a844dd68ee7d139 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 24 Mar 2023 15:59:27 -0300
Subject: [PATCH 45/46] Make the Stop button more consistent in chat mode
---
server.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/server.py b/server.py
index 435b8525..7b25e91d 100644
--- a/server.py
+++ b/server.py
@@ -329,7 +329,7 @@ def create_interface():
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
- shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
+ shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
From d8e950d6bdf933f8a0cd78a0c7cb2a941b8d32e3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 24 Mar 2023 16:30:32 -0300
Subject: [PATCH 46/46] Don't load the model twice when using --lora
---
server.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/server.py b/server.py
index 7b25e91d..f423e368 100644
--- a/server.py
+++ b/server.py
@@ -233,9 +233,7 @@ else:
shared.model_name = available_models[i]
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
- print(shared.args.lora)
- shared.lora_name = shared.args.lora
- add_lora_to_model(shared.lora_name)
+ add_lora_to_model(shared.args.lora)
# Default UI settings
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]