mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Fix Continue for LLaVA (#1507)
This commit is contained in:
parent
12212cf6be
commit
04b98a8485
@ -91,7 +91,7 @@ class LLaVAEmbedder:
|
||||
# replace the image token with the image patch token in the prompt (each occurrence)
|
||||
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
|
||||
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
|
||||
prompt = re.sub(r"<image:([A-Za-z0-9+/=]+)>", replace_token, prompt, 1)
|
||||
prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
|
||||
return prompt
|
||||
|
||||
def _extract_image_features(self, images):
|
||||
@ -146,11 +146,11 @@ class LLaVAEmbedder:
|
||||
|
||||
@staticmethod
def len_in_tokens(text):
    """Return the token length of ``text``, counting each embedded image as 258 tokens.

    Image tags are recognized in either of the two forms this file uses:
    ``<image:BASE64>`` and ``<img src="data:image/jpeg;base64,BASE64">``.
    Every tag found is removed from the text before the remaining text is
    passed to ``encode`` and is charged a fixed 258 tokens instead (256 image
    patch tokens plus one start and one end marker, matching the replacement
    string built elsewhere in this class).

    Args:
        text: Prompt text that may contain embedded image tags.

    Returns:
        ``len(encode(stripped_text)[0])`` plus 258 per image tag.
    """
    # One pattern for both tag forms so that exactly the tags that are counted
    # are also the ones stripped before tokenizing; otherwise the base64
    # payload of a counted image would be tokenized as ordinary text.
    image_pattern = (
        r'<image:[A-Za-z0-9+/=]+>'
        r'|<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">'
    )
    image_tokens = 258 * len(re.findall(image_pattern, text))
    # NOTE(review): `encode` is defined outside this block; presumably it
    # returns a batch whose first row holds the token ids - confirm.
    return len(encode(re.sub(image_pattern, '', text))[0]) + image_tokens
|
||||
|
||||
|
||||
def add_chat_picture(picture, text, visible_text):
|
||||
@ -166,32 +166,21 @@ def add_chat_picture(picture, text, visible_text):
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
visible = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
internal = f'<image:{img_str}>'
|
||||
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', image)
|
||||
else:
|
||||
text = text + '\n' + image
|
||||
|
||||
if visible_text == '' or visible_text is None:
|
||||
visible_text = text
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', internal)
|
||||
elif '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', image)
|
||||
else:
|
||||
text = text + '\n' + internal
|
||||
visible_text = visible_text + '\n' + image
|
||||
|
||||
if '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', visible)
|
||||
else:
|
||||
visible_text = visible_text + '\n' + visible
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def fix_picture_after_remove_last(text, visible_text):
    """Convert embedded HTML image tags in ``text`` to the ``<image:...>`` form.

    Every ``<img src="data:image/jpeg;base64,BASE64">`` tag in ``text`` is
    rewritten as ``<image:BASE64>``. When no such tag is present, both
    arguments are returned untouched.

    Args:
        text: Message text that may contain HTML image tags.
        visible_text: Display version of the message, or ``None``.

    Returns:
        A ``(text, visible_text)`` tuple. If at least one tag was converted and
        ``visible_text`` was ``None``, the original (unconverted) ``text`` is
        used as ``visible_text``.
    """
    html_image = r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">'
    # A single substitution pass both detects and rewrites the tags.
    converted, replaced = re.subn(html_image, "<image:\\1>", text)
    if not replaced:
        return text, visible_text
    if visible_text is None:
        # Keep the HTML form for display so the picture still renders.
        visible_text = text
    return converted, visible_text
|
||||
|
||||
|
||||
@ -248,7 +237,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||
global params
|
||||
start_ts = time.time()
|
||||
image_matches = re.finditer(r"<image:([A-Za-z0-9+/=]+)>", prompt)
|
||||
image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
|
||||
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
|
||||
|
||||
if len(images) == 0:
|
||||
@ -276,4 +265,3 @@ def ui():
|
||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
||||
shared.gradio['Remove last'].click(lambda: input_hijack.update({"state": True, "value": fix_picture_after_remove_last}), None, None)
|
||||
|
Loading…
Reference in New Issue
Block a user