mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
LLaVA support (#1487)
This commit is contained in:
parent
9197d3fec8
commit
12212cf6be
3
characters/instruction-following/LLaVA.yaml
Normal file
3
characters/instruction-following/LLaVA.yaml
Normal file
@ -0,0 +1,3 @@
|
||||
name: "### Assistant"
|
||||
your_name: "### Human"
|
||||
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
|
@ -33,6 +33,7 @@ Most of these have been created by the extremely talented contributors that you
|
||||
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|
||||
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|
||||
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|
||||
|[llava](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava) | Adds LLaVA multimodal model support. For detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava/README.md) in the extension directory. |
|
||||
|
||||
## How to write an extension
|
||||
|
||||
@ -45,6 +46,7 @@ Most of these have been created by the extremely talented contributors that you
|
||||
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
|
||||
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
|
||||
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
|
||||
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
|
||||
|
||||
Additionally, the script may define two special global variables:
|
||||
|
||||
@ -70,7 +72,9 @@ input_hijack = {
|
||||
'value': ["", ""]
|
||||
}
|
||||
```
|
||||
This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the vales inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
|
||||
This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
|
||||
|
||||
Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `llava` extension above for an example.
|
||||
|
||||
## The `bot_prefix_modifier`
|
||||
|
||||
|
49
extensions/llava/README.md
Normal file
49
extensions/llava/README.md
Normal file
@ -0,0 +1,49 @@
|
||||
# LLaVA
|
||||
|
||||
## Description
|
||||
Adds [LLaVA](https://github.com/haotian-liu/LLaVA) multimodality support to text-generation-webui.
|
||||
|
||||
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
|
||||
|
||||
## Usage
|
||||
To run this extension, download LLaVA weights, for example from [here](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g), and then start server.py with `--extensions llava` argument.
|
||||
|
||||
When in ui, go to instruct mode, and select LLaVA template, you also should add `"\n###"` to "Custom stopping strings" in parameters tab.
|
||||
|
||||
Do note, that each image takes up 258 tokens, so adjust max_new_tokens to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
|
||||
|
||||
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
|
||||
|
||||
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. From initial testing, it seems as LLaVA considers the features in all images at the same time, so by default the extension skips previous images. If you want to include them anyway, just tick this checkbox.
|
||||
|
||||
## Extension config
|
||||
This extension uses following parameters (from settings.json):
|
||||
|Parameter|Description|
|
||||
|---------|-----------|
|
||||
|`llava-clip_bits`|Number of bits to load CLIP feature extractor in (either 32 or 16, default=32)|
|
||||
|`llava-clip_device`|Torch device to run the extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`llava-projector_bits`|Number of bits to load CLIP->LLaMA feature projector in (either 32 or 16, default=32)|
|
||||
|`llava-projector_bits`|Torch device to run the CLIP->LLaMA feature projector on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`llava-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
|
||||
|
||||
## Technical description
|
||||
|
||||
### Original LLaVA
|
||||
The default LLaVA implementation uses modified `transformers` library, however this extension forgoes this requirement. The transformers are modified in LLaVA in such a way, that the entire LLaVA model gets loaded, and the inference now looks as follows:
|
||||
```
|
||||
images --> CLIP --> projector --> input embeddings for images --> |
|
||||
| --> LLaMA
|
||||
prompt -------------------------> input embeddings for text ----> |
|
||||
```
|
||||
The images are represented in the prompt by the following token IDs:
|
||||
- 32000 - `<im_patch>` - placeholder token for embeddings from projector
|
||||
- 32001 - `<im_start>` - token marking start of an image
|
||||
- 32002 - `<im_end>` - token marking end of an image
|
||||
|
||||
By default, image will be represented as `<im_start><im_patch>*256<im_end>`. The input embeddings for an image are converted with a single linear layer of the projector, then they are placed instead of `<im_patch>` tokens.
|
||||
The concatenated prompt then gets fed to fine-tuned LLaMA.
|
||||
|
||||
### In this extension
|
||||
|
||||
Using default transformers, they only load the LLaMA part of LLaVA, ignoring the added projector weights, and not loading CLIP. We then reconstruct the `images -> CLIP -> projector` pipeline ourselves, then concatenate the input embeddings, and feed it to LLaMA loaded by transformers. This allows us to use normal flow from webui to load this model, and just hijack the model input with additional features.
|
||||
Splitting it to 3 separate models, allows us to configure each of them, and to move them to different devices(for example we can run CLIP+projector on CPU and LLaMA on GPU). Also, it enables us to use 4-bit GPTQ quantization for LLaVA, massively cutting down the VRAM requirement (it should be possible to fit on 12GB of VRAM with full context size by moving CLIP and projector to CPU).
|
279
extensions/llava/script.py
Normal file
279
extensions/llava/script.py
Normal file
@ -0,0 +1,279 @@
|
||||
import base64
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
from modules import shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.text_generation import encode, get_max_prompt_length
|
||||
|
||||
params = {
|
||||
"add_all_images_to_prompt": False,
|
||||
# device to run CLIP on
|
||||
"clip_device": None,
|
||||
# bits to load clip in either 32 or 16 (it doesn't support 8-bit)
|
||||
"clip_bits": 32,
|
||||
# device to run projector on
|
||||
"projector_device": None,
|
||||
# projector bits, either 32 or 16
|
||||
"projector_bits": 32
|
||||
}
|
||||
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
'value': ["", ""]
|
||||
}
|
||||
|
||||
|
||||
# initialized in ui, so that params are loaded from settings
|
||||
llava_embedder = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
token: str
|
||||
id: int
|
||||
|
||||
|
||||
class LLaVAEmbedder:
|
||||
IM_PATCH = Token("<im_patch>", 32000)
|
||||
IM_START = Token("<im_start>", 32001)
|
||||
IM_END = Token("<im_end>", 32002)
|
||||
CLIP_VIT_HUB_NAME = 'openai/clip-vit-large-patch14'
|
||||
PROJECTOR_HUB_NAME = 'liuhaotian/LLaVA-13b-pretrain-projector-v0'
|
||||
PROJECTOR_FILE = 'LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption.bin'
|
||||
|
||||
def __init__(self):
|
||||
self.clip_device = self._get_device("clip_device")
|
||||
self.clip_dtype = self._get_dtype("clip_bits")
|
||||
self.projector_device = self._get_device("projector_device")
|
||||
self.projector_dtype = self._get_dtype("projector_bits")
|
||||
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
|
||||
|
||||
def _get_device(self, setting_name):
|
||||
if params[setting_name] is None:
|
||||
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(params[setting_name])
|
||||
|
||||
def _get_dtype(self, setting_name):
|
||||
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
|
||||
|
||||
def _load_models(self):
|
||||
start_ts = time.time()
|
||||
|
||||
print(f"LLaVA - Loading {LLaVAEmbedder.CLIP_VIT_HUB_NAME} as {self.clip_dtype} on {self.clip_device}...")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype)
|
||||
vision_tower = CLIPVisionModel.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||
|
||||
print(f"LLaVA - Loading {LLaVAEmbedder.PROJECTOR_HUB_NAME} as {self.projector_dtype} on {self.projector_device}...")
|
||||
projector_path = hf_hub_download(LLaVAEmbedder.PROJECTOR_HUB_NAME, LLaVAEmbedder.PROJECTOR_FILE)
|
||||
mm_projector = torch.nn.Linear(1024, 5120)
|
||||
projector_data = torch.load(projector_path)
|
||||
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector = mm_projector.to(self.projector_device)
|
||||
|
||||
print(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
return image_processor, vision_tower, mm_projector
|
||||
|
||||
def _update_prompt(self, prompt, images):
|
||||
for _ in images:
|
||||
# replace the image token with the image patch token in the prompt (each occurrence)
|
||||
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
|
||||
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
|
||||
prompt = re.sub(r"<image:([A-Za-z0-9+/=]+)>", replace_token, prompt, 1)
|
||||
return prompt
|
||||
|
||||
def _extract_image_features(self, images):
|
||||
images = self.image_processor(images, return_tensors='pt')['pixel_values']
|
||||
images = images.to(self.clip_device, dtype=self.clip_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
||||
select_hidden_state_layer = -2
|
||||
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
||||
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
|
||||
image_features = self.mm_projector(image_features)
|
||||
return image_features
|
||||
|
||||
def forward(self, prompt, images, state):
|
||||
prompt = self._update_prompt(prompt, images)
|
||||
input_ids = encode(prompt, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))[0]
|
||||
input_embeds = shared.model.model.embed_tokens(input_ids).to(self.projector_device)
|
||||
|
||||
if input_ids[0] == LLaVAEmbedder.IM_PATCH.id:
|
||||
# prompt got truncated in the middle of an image, remove the image data
|
||||
im_end = torch.where(input_ids == LLaVAEmbedder.IM_END.id)[0][0]
|
||||
input_ids = input_ids[im_end+1:]
|
||||
input_embeds = input_embeds[im_end+1:]
|
||||
leftover_images = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0].shape[0]
|
||||
print(f"LLaVA - WARNING: removed {len(images) - leftover_images} image(s) from prompt. The generation might be broken, try decreasing max_new_tokens")
|
||||
images = images[-leftover_images:]
|
||||
if len(images) == 0:
|
||||
return prompt, input_ids, input_embeds, 0
|
||||
|
||||
total_embedded = 0
|
||||
image_features = self._extract_image_features(images).to(self.projector_device)
|
||||
image_start_tokens = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0]
|
||||
|
||||
if not torch.any(input_ids == LLaVAEmbedder.IM_PATCH.id) or len(image_start_tokens) == 0:
|
||||
# multimodal LLM, but the current prompt is not multimodal/truncated
|
||||
return prompt, input_ids, input_embeds, total_embedded
|
||||
|
||||
cur_image_idx = 0
|
||||
if not params['add_all_images_to_prompt']:
|
||||
image_start_tokens = [image_start_tokens[-1]]
|
||||
cur_image_idx = -1
|
||||
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
input_embeds = torch.cat((input_embeds[:image_start_token_pos+1], cur_image_features, input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
|
||||
cur_image_idx += 1
|
||||
total_embedded += 1
|
||||
|
||||
return prompt, input_ids, input_embeds, total_embedded
|
||||
|
||||
@staticmethod
|
||||
def len_in_tokens(text):
|
||||
images = re.findall(r"<image:[A-Za-z0-9+/=]+>", text)
|
||||
image_tokens = 0
|
||||
for _ in images:
|
||||
image_tokens += 258
|
||||
return len(encode(re.sub(r"<image:[A-Za-z0-9+/=]+>", '', text))[0]) + image_tokens
|
||||
|
||||
|
||||
def add_chat_picture(picture, text, visible_text):
|
||||
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
|
||||
max_hw, min_hw = max(picture.size), min(picture.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
shortest_edge = int(max(300 / aspect_ratio, 224))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
w = shortest_edge if picture.width < picture.height else longest_edge
|
||||
h = shortest_edge if picture.width >= picture.height else longest_edge
|
||||
picture = picture.resize((w,h))
|
||||
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
visible = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
internal = f'<image:{img_str}>'
|
||||
|
||||
if visible_text == '' or visible_text is None:
|
||||
visible_text = text
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', internal)
|
||||
else:
|
||||
text = text + '\n' + internal
|
||||
|
||||
if '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', visible)
|
||||
else:
|
||||
visible_text = visible_text + '\n' + visible
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def fix_picture_after_remove_last(text, visible_text):
|
||||
image = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', text)
|
||||
if image is None:
|
||||
return text, visible_text
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
text = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', "<image:\\1>", text)
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
rows = [f"{state['context'].strip()}\n"]
|
||||
min_rows = 3
|
||||
|
||||
# Finding the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
prefix1 = f"{state['name1']}: "
|
||||
prefix2 = f"{state['name2']}: "
|
||||
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and LLaVAEmbedder.len_in_tokens(''.join(rows)) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||
else:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
||||
|
||||
string = shared.history['internal'][i][0]
|
||||
if string != '':
|
||||
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
||||
|
||||
i -= 1
|
||||
|
||||
if impersonate:
|
||||
min_rows = 2
|
||||
rows.append(f"{prefix1}")
|
||||
elif not _continue:
|
||||
# Adding the user message
|
||||
if len(user_input) > 0:
|
||||
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", f"{prefix2}"))
|
||||
|
||||
while len(rows) > min_rows and LLaVAEmbedder.len_in_tokens(''.join(rows)) >= max_length:
|
||||
rows.pop(1)
|
||||
prompt = ''.join(rows)
|
||||
|
||||
if also_return_rows:
|
||||
return prompt, rows
|
||||
else:
|
||||
return prompt
|
||||
|
||||
|
||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||
global params
|
||||
start_ts = time.time()
|
||||
image_matches = re.finditer(r"<image:([A-Za-z0-9+/=]+)>", prompt)
|
||||
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
|
||||
|
||||
if len(images) == 0:
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
|
||||
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
||||
return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device)
|
||||
|
||||
|
||||
def ui():
|
||||
global llava_embedder
|
||||
llava_embedder = LLaVAEmbedder()
|
||||
with gr.Column():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
# I found that it doesn't deal super well with multiple images, and demo ui had a bug where it included only the last image anyway
|
||||
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
|
||||
# Prepare the input hijack
|
||||
picture_select.upload(
|
||||
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
|
||||
[picture_select],
|
||||
None
|
||||
)
|
||||
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
|
||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
||||
shared.gradio['Remove last'].click(lambda: input_hijack.update({"state": True, "value": fix_picture_after_remove_last}), None, None)
|
@ -48,3 +48,7 @@ llama-[0-9]*b-4bit$:
|
||||
.*chatglm:
|
||||
mode: 'instruct'
|
||||
instruction_template: 'ChatGLM'
|
||||
.*llava:
|
||||
mode: 'instruct'
|
||||
model_type: 'llama'
|
||||
instruction_template: 'LLaVA'
|
||||
|
@ -135,7 +135,7 @@ def load_quantized(model_name):
|
||||
# Find the model type
|
||||
if not shared.args.model_type:
|
||||
name = model_name.lower()
|
||||
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
|
||||
if any((k in name for k in ['llama', 'alpaca', 'vicuna', 'llava'])):
|
||||
model_type = 'llama'
|
||||
elif any((k in name for k in ['opt-', 'galactica'])):
|
||||
model_type = 'opt'
|
||||
|
@ -64,7 +64,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
|
||||
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
@ -127,29 +127,22 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
visible_text = None
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
text, visible_text = apply_extensions('input_hijack', text, visible_text)
|
||||
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
if not _continue:
|
||||
text = apply_extensions(text, "input")
|
||||
text = apply_extensions("input", text)
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'_continue': _continue}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
||||
if prompt is None:
|
||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not any((regenerate, _continue)):
|
||||
@ -164,7 +157,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions(visible_reply, "output")
|
||||
visible_reply = apply_extensions("output", visible_reply)
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
@ -273,14 +266,14 @@ def send_last_reply_to_input():
|
||||
def replace_last_reply(text, name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0:
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def send_dummy_message(text, name1, name2, mode):
|
||||
shared.history['visible'].append([text, ''])
|
||||
shared.history['internal'].append([apply_extensions(text, "input"), ''])
|
||||
shared.history['internal'].append([apply_extensions("input", text), ''])
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
@ -289,7 +282,7 @@ def send_dummy_reply(text, name1, name2, mode):
|
||||
shared.history['visible'].append(['', ''])
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
@ -303,7 +296,7 @@ def clear_chat_log(name1, name2, greeting, mode):
|
||||
|
||||
if greeting != '':
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Save cleared logs
|
||||
save_history(mode)
|
||||
@ -475,7 +468,7 @@ def load_character(character, name1, name2, mode):
|
||||
# Insert greeting if it exists
|
||||
if greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Create .json log files since they don't already exist
|
||||
save_history(mode)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@ -39,17 +40,60 @@ def iterator():
|
||||
|
||||
|
||||
# Extension functions that map string -> string
|
||||
def apply_extensions(text, typ):
|
||||
def _apply_string_extensions(function_name, text):
|
||||
for extension, _ in iterator():
|
||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||
text = extension.input_modifier(text)
|
||||
elif typ == "output" and hasattr(extension, "output_modifier"):
|
||||
text = extension.output_modifier(text)
|
||||
elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
|
||||
text = extension.bot_prefix_modifier(text)
|
||||
if hasattr(extension, function_name):
|
||||
text = getattr(extension, function_name)(text)
|
||||
return text
|
||||
|
||||
|
||||
# Input hijack of extensions
|
||||
def _apply_input_hijack(text, visible_text):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
if callable(extension.input_hijack['value']):
|
||||
text, visible_text = extension.input_hijack['value'](text, visible_text)
|
||||
else:
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
return text, visible_text
|
||||
|
||||
|
||||
# custom_generate_chat_prompt handling
|
||||
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in iterator():
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
if custom_generate_chat_prompt is not None:
|
||||
return custom_generate_chat_prompt(text, state, **kwargs)
|
||||
return None
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
EXTENSION_MAP = {
|
||||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
|
||||
}
|
||||
|
||||
|
||||
def apply_extensions(typ, *args, **kwargs):
|
||||
if typ not in EXTENSION_MAP:
|
||||
raise ValueError(f"Invalid extension type {typ}")
|
||||
return EXTENSION_MAP[typ](*args, **kwargs)
|
||||
|
||||
|
||||
def create_extensions_block():
|
||||
global setup_called
|
||||
|
||||
|
@ -50,6 +50,8 @@ def find_model_type(model_name):
|
||||
return 'chatglm'
|
||||
elif 'galactica' in model_name:
|
||||
return 'galactica'
|
||||
elif 'llava' in model_name:
|
||||
return 'llava'
|
||||
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
@ -217,11 +219,12 @@ def load_model(model_name):
|
||||
tokenizer = None
|
||||
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
if shared.model_type != 'llava':
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
|
||||
# Otherwise, load it from the model folder and hope that these
|
||||
# are not outdated tokenizer files.
|
||||
|
@ -56,7 +56,7 @@ settings = {
|
||||
'chat_default_extensions': ["gallery"],
|
||||
'presets': {
|
||||
'default': 'Default',
|
||||
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||
'.*(alpaca|llama|llava)': "LLaMA-Precise",
|
||||
'.*pygmalion': 'NovelAI-Storywriter',
|
||||
'.*RWKV': 'Naive',
|
||||
},
|
||||
|
@ -138,7 +138,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions(question, 'input')
|
||||
question = apply_extensions('input', question)
|
||||
|
||||
# These models are not part of Hugging Face, so we handle them
|
||||
# separately and terminate the function call earlier
|
||||
@ -155,7 +155,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
@ -166,7 +166,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
@ -179,7 +179,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
return
|
||||
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
if shared.args.verbose:
|
||||
@ -218,10 +217,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
generate_params.update({'synced_gpus': True})
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
@ -237,7 +242,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
@ -265,7 +270,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
@ -285,7 +290,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
@ -30,7 +30,7 @@
|
||||
],
|
||||
"presets": {
|
||||
"default": "Default",
|
||||
".*(alpaca|llama)": "LLaMA-Precise",
|
||||
".*(alpaca|llama|llava)": "LLaMA-Precise",
|
||||
".*pygmalion": "NovelAI-Storywriter",
|
||||
".*RWKV": "Naive"
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user