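"""Multimodal extension: attach a picture to a chat message.

An uploaded image is inlined into the prompt as a base64-encoded <img> tag,
which tokenizer_modifier() later replaces with vision-encoder embeddings via
MultimodalEmbedder.
"""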
import base64
import re
import time
from functools import partial
from io import BytesIO
from typing import Optional
import gradio as gr
import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared
from modules.logging_colors import logger
params = {
"add_all_images_to_prompt": False,
# device to run vision encoder on
"vision_device": None,
# bits to load vision encoder in, either 16 or 32
"vision_bits": 32,
# device to run multimodal projector on
"projector_device": None,
    # bits to load multimodal projector in, either 16 or 32
"projector_bits": 32
}
# If 'state' is True, the next chat generation is hijacked: chat_input_modifier
# calls 'value' (a partial of add_chat_picture, set when an image is uploaded)
input_hijack = {
'state': False,
'value': ["", ""]
}
# Initialized in ui(), so that params are loaded from settings
multimodal_embedder: Optional[MultimodalEmbedder] = None


def chat_input_modifier(text, visible_text, state):
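    """Chat input modifier hook: return (text, visible_text) for the next message.

    If the hijack is armed (a picture was just uploaded), 'value' holds a
    partial of add_chat_picture, which attaches the image to the message.
    """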
global input_hijack
if input_hijack['state']:
input_hijack['state'] = False
return input_hijack['value'](text, visible_text)
else:
return text, visible_text


def add_chat_picture(picture, text, visible_text):
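    """Inline `picture` into the user message as a base64-encoded <img> tag.

    The tag replaces an explicit '<image>' placeholder if present; otherwise it
    is prepended to the internal text and appended to the visible text.
    """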
    # Resize the image so that the shortest edge is 336 px (the CLIP input
    # resolution used by llava-v1.5; earlier CLIP checkpoints used 224) and the
    # longest edge scales with the aspect ratio, to keep history manageable
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
    shortest_edge = 336  # max(336 / aspect_ratio, 336) is always 336, since aspect_ratio >= 1
longest_edge = int(shortest_edge * aspect_ratio)
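    # Map the shortest/longest edge onto width and height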
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w, h))
buffer = BytesIO()
picture.save(buffer, format="PNG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    # Note: the payload is PNG, but the MIME label stays "image/jpeg" because
    # that is the pattern tokenizer_modifier() matches on
    image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = image + '\n' + text
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text


def custom_tokenized_length(prompt):
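    """Report prompt length via the embedder, so that embedded images are
    counted by their token cost rather than the length of their base64 text."""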
return multimodal_embedder.len_in_tokens(prompt)


def tokenizer_modifier(state, prompt, input_ids, input_embeds):
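    """Replace base64 <img> tags in the prompt with image embeddings.

    If no image tag is present, the inputs pass through unchanged; otherwise
    the embedder rewrites the prompt and returns batched ids/embeds.
    """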
start_ts = time.time()
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
if image_match is None:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
    logger.info(f'Embedded {total_embedded} image(s) in {time.time() - start_ts:.2f}s')
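    # Add a batch dimension and match the model's device and dtype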
return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))


def ui():
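    """Build the extension's Gradio UI and initialize the multimodal embedder."""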
global multimodal_embedder
multimodal_embedder = MultimodalEmbedder(params)
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
        # The models don't seem to deal well with multiple images,
        # so by default only the last image in the prompt is embedded
        all_images_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
        # Arm the input hijack so the next chat message gets the image attached
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
        all_images_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), all_images_checkbox, None)
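        # Clear the uploaded picture once a generation is triggered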
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)