Merge pull request #489 from Brawlence/ext-fixes

Extensions performance & memory optimisations
This commit is contained in:
oobabooga 2023-03-22 16:10:59 -03:00 committed by GitHub
commit d5fc1bead7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 41 deletions

View File

@ -1,11 +1,11 @@
import re
from pathlib import Path
import gradio as gr
import modules.shared as shared
from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path
import modules.shared as shared
params = {
'activate': True,
'api_key': '12345',
@ -52,14 +52,9 @@ def refresh_voices():
return
def remove_surrounded_chars(string):
    """
    Return ``string`` with every asterisk-delimited span removed.

    A span starts at an asterisk and runs up to and including the next
    asterisk. If there is no closing asterisk, the span runs to the end of
    the string, so everything after an unmatched asterisk is dropped too.
    """
    # The pattern reads: an asterisk, then as few non-asterisk characters as
    # possible (zero or more), then either a closing asterisk or the end of
    # the string. It is a raw string because '\*' is not a valid escape
    # sequence in a normal string literal and triggers a warning on newer
    # Python versions.
    return re.sub(r'\*[^\*]*?(\*|$)', '', string)
def input_modifier(string):
"""
@ -115,4 +110,4 @@ def ui():
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
connect.click(check_valid_api, [], connection_status)
connect.click(refresh_voices, [], voice)
connect.click(refresh_voices, [], voice)

View File

@ -1,15 +1,15 @@
import base64
import io
import re
from pathlib import Path
import gradio as gr
import modules.chat as chat
import modules.shared as shared
import requests
import torch
from PIL import Image
import modules.chat as chat
import modules.shared as shared
torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui
@ -31,14 +31,9 @@ picture_response = False # specifies if the next model response should appear as
pic_id = 0
def remove_surrounded_chars(string):
    """
    Return ``string`` with every asterisk-delimited span removed.

    A span starts at an asterisk and runs up to and including the next
    asterisk. If there is no closing asterisk, the span runs to the end of
    the string, so everything after an unmatched asterisk is dropped too.
    """
    # The pattern reads: an asterisk, then as few non-asterisk characters as
    # possible (zero or more), then either a closing asterisk or the end of
    # the string. It is a raw string because '\*' is not a valid escape
    # sequence in a normal string literal and triggers a warning on newer
    # Python versions.
    return re.sub(r'\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be committed to history as the unmodified string
def input_modifier(string):
@ -54,6 +49,8 @@ def input_modifier(string):
mediums = ['image', 'pic', 'picture', 'photo']
subjects = ['yourself', 'own']
lowstr = string.lower()
# TODO: refactor out to separate handler and also replace detection with a regexp
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
picture_response = True
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
@ -91,9 +88,8 @@ def get_SD_pictures(description):
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
image.save(output_file.as_posix())
pic_id += 1
# lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
newsize = (300, 300)
image = image.resize(newsize, Image.LANCZOS)
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
image.thumbnail((300, 300))
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
buffered.seek(0)
@ -180,4 +176,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

View File

@ -2,11 +2,11 @@ import base64
from io import BytesIO
import gradio as gr
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
import modules.chat as chat
import modules.shared as shared
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
@ -25,10 +25,12 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2):
    """
    Build the chat entry for a picture sent by ``name1`` to ``name2``.

    ``picture`` is presumably a PIL image (``thumbnail`` and ``save`` are
    called on it) -- TODO confirm against the caller.

    Returns a ``(text, visible_text)`` pair: ``text`` is the asterisk-wrapped
    description containing the caption from ``caption_image``, and
    ``visible_text`` is an ``<img>`` tag embedding the picture as a
    base64-encoded JPEG with ``text`` as its alt attribute.
    """
    import html  # local import: only needed for escaping the alt text below

    text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
    # lower the resolution of sent images for the chat, otherwise the log size
    # gets out of control quickly with all the base64 values in visible history
    picture.thumbnail((300, 300))
    buffer = BytesIO()
    picture.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    # `text` always contains literal double quotes around the caption, so it
    # must be escaped before being placed inside the double-quoted alt
    # attribute; otherwise the attribute ends at the first quote.
    visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{html.escape(text)}">'
    return text, visible_text
def ui():

View File

@ -1,11 +1,11 @@
import re
import time
from pathlib import Path
import gradio as gr
import torch
import modules.chat as chat
import modules.shared as shared
import torch
torch._C._jit_set_profiling_mode(False)
@ -46,14 +46,9 @@ def load_model():
model = load_model()
def remove_surrounded_chars(string):
    """
    Return ``string`` with every asterisk-delimited span removed.

    A span starts at an asterisk and runs up to and including the next
    asterisk. If there is no closing asterisk, the span runs to the end of
    the string, so everything after an unmatched asterisk is dropped too.
    """
    # The pattern reads: an asterisk, then as few non-asterisk characters as
    # possible (zero or more), then either a closing asterisk or the end of
    # the string. It is a raw string because '\*' is not a valid escape
    # sequence in a normal string literal and triggers a warning on newer
    # Python versions.
    return re.sub(r'\*[^\*]*?(\*|$)', '', string)
def remove_tts_from_history(name1, name2):
for i, entry in enumerate(shared.history['internal']):
@ -166,4 +161,4 @@ def ui():
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)