mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Add TensorRT-LLM support (#5715)
This commit is contained in:
parent
536f8d58d4
commit
577a8cd3ee
27
docker/TensorRT-LLM/Dockerfile
Normal file
27
docker/TensorRT-LLM/Dockerfile
Normal file
@ -0,0 +1,27 @@
|
||||
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
|
||||
|
||||
# Install Git
|
||||
RUN apt update && apt install -y git
|
||||
|
||||
# System-wide TensorRT-LLM requirements
|
||||
RUN apt install -y openmpi-bin libopenmpi-dev
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install text-generation-webui
|
||||
RUN git clone https://github.com/oobabooga/text-generation-webui
|
||||
WORKDIR /app/text-generation-webui
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
# This is needed to avoid an error about "Failed to build mpi4py" in the next command
|
||||
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
|
||||
|
||||
# Install TensorRT-LLM
|
||||
RUN pip3 install tensorrt_llm==0.10.0 -U --pre --extra-index-url https://pypi.nvidia.com
|
||||
|
||||
# Expose the necessary port for the Python server
|
||||
EXPOSE 7860 5000
|
||||
|
||||
# Run the Python server.py script with the specified command
|
||||
CMD ["python", "server.py", "--api", "--listen"]
|
@ -131,6 +131,11 @@ loaders_and_params = OrderedDict({
|
||||
'hqq_backend',
|
||||
'trust_remote_code',
|
||||
'no_use_fast',
|
||||
],
|
||||
'TensorRT-LLM': [
|
||||
'max_seq_len',
|
||||
'cpp_runner',
|
||||
'tensorrt_llm_info',
|
||||
]
|
||||
})
|
||||
|
||||
@ -316,6 +321,16 @@ loaders_samplers = {
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'TensorRT-LLM': {
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'repetition_penalty',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'ban_eos_token',
|
||||
'auto_max_new_tokens',
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -77,6 +77,7 @@ def load_model(model_name, loader=None):
|
||||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
||||
'AutoAWQ': AutoAWQ_loader,
|
||||
'HQQ': HQQ_loader,
|
||||
'TensorRT-LLM': TensorRT_LLM_loader,
|
||||
}
|
||||
|
||||
metadata = get_model_metadata(model_name)
|
||||
@ -101,7 +102,7 @@ def load_model(model_name, loader=None):
|
||||
tokenizer = load_tokenizer(model_name, model)
|
||||
|
||||
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
|
||||
if loader.lower().startswith('exllama'):
|
||||
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
|
||||
shared.settings['truncation_length'] = shared.args.max_seq_len
|
||||
elif loader in ['llama.cpp', 'llamacpp_HF']:
|
||||
shared.settings['truncation_length'] = shared.args.n_ctx
|
||||
@ -337,6 +338,13 @@ def HQQ_loader(model_name):
|
||||
return model
|
||||
|
||||
|
||||
def TensorRT_LLM_loader(model_name):
|
||||
from modules.tensorrt_llm import TensorRTLLMModel
|
||||
|
||||
model = TensorRTLLMModel.from_pretrained(model_name)
|
||||
return model
|
||||
|
||||
|
||||
def get_max_memory_dict():
|
||||
max_memory = {}
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
|
@ -81,6 +81,9 @@ def get_model_metadata(model):
|
||||
# Transformers metadata
|
||||
if hf_metadata is not None:
|
||||
metadata = json.loads(open(path, 'r', encoding='utf-8').read())
|
||||
if 'pretrained_config' in metadata:
|
||||
metadata = metadata['pretrained_config']
|
||||
|
||||
for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
|
||||
if k in metadata:
|
||||
model_settings['truncation_length'] = metadata[k]
|
||||
|
@ -165,6 +165,10 @@ group.add_argument('--no_inject_fused_attention', action='store_true', help='Dis
|
||||
group = parser.add_argument_group('HQQ')
|
||||
group.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN.')
|
||||
|
||||
# TensorRT-LLM
|
||||
group = parser.add_argument_group('TensorRT-LLM')
|
||||
group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.')
|
||||
|
||||
# DeepSpeed
|
||||
group = parser.add_argument_group('DeepSpeed')
|
||||
group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
|
||||
@ -263,6 +267,8 @@ def fix_loader_name(name):
|
||||
return 'AutoAWQ'
|
||||
elif name in ['hqq']:
|
||||
return 'HQQ'
|
||||
elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']:
|
||||
return 'TensorRT-LLM'
|
||||
|
||||
|
||||
def add_extension(name, last=False):
|
||||
|
131
modules/tensorrt_llm.py
Normal file
131
modules/tensorrt_llm.py
Normal file
@ -0,0 +1,131 @@
|
||||
from pathlib import Path
|
||||
|
||||
import tensorrt_llm
|
||||
import torch
|
||||
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import (
|
||||
get_max_prompt_length,
|
||||
get_reply_from_output_ids
|
||||
)
|
||||
|
||||
|
||||
class TensorRTLLMModel:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, path_to_model):
|
||||
|
||||
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
|
||||
# Define model settings
|
||||
runner_kwargs = dict(
|
||||
engine_dir=str(path_to_model),
|
||||
lora_dir=None,
|
||||
rank=runtime_rank,
|
||||
debug_mode=False,
|
||||
lora_ckpt_source="hf",
|
||||
)
|
||||
|
||||
if shared.args.cpp_runner:
|
||||
logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
|
||||
runner_kwargs.update(
|
||||
max_batch_size=1,
|
||||
max_input_len=shared.args.max_seq_len - 512,
|
||||
max_output_len=512,
|
||||
max_beam_width=1,
|
||||
max_attention_window_size=None,
|
||||
sink_token_length=None,
|
||||
)
|
||||
else:
|
||||
logger.info("TensorRT-LLM: Using \"ModelRunner\"")
|
||||
|
||||
# Load the model
|
||||
runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
|
||||
runner = runner_cls.from_dir(**runner_kwargs)
|
||||
|
||||
result = self()
|
||||
result.model = runner
|
||||
result.runtime_rank = runtime_rank
|
||||
|
||||
return result
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
batch_input_ids = []
|
||||
input_ids = shared.tokenizer.encode(
|
||||
prompt,
|
||||
add_special_tokens=True,
|
||||
truncation=False,
|
||||
)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int32)
|
||||
input_ids = input_ids[-get_max_prompt_length(state):] # Apply truncation_length
|
||||
batch_input_ids.append(input_ids)
|
||||
|
||||
if shared.args.cpp_runner:
|
||||
max_new_tokens = min(512, state['max_new_tokens'])
|
||||
elif state['auto_max_new_tokens']:
|
||||
max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
|
||||
else:
|
||||
max_new_tokens = state['max_new_tokens']
|
||||
|
||||
with torch.no_grad():
|
||||
generator = self.model.generate(
|
||||
batch_input_ids,
|
||||
max_new_tokens=max_new_tokens,
|
||||
max_attention_window_size=None,
|
||||
sink_token_length=None,
|
||||
end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
|
||||
pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
|
||||
temperature=state['temperature'],
|
||||
top_k=state['top_k'],
|
||||
top_p=state['top_p'],
|
||||
num_beams=1,
|
||||
length_penalty=1.0,
|
||||
repetition_penalty=state['repetition_penalty'],
|
||||
presence_penalty=state['presence_penalty'],
|
||||
frequency_penalty=state['frequency_penalty'],
|
||||
stop_words_list=None,
|
||||
bad_words_list=None,
|
||||
lora_uids=None,
|
||||
prompt_table_path=None,
|
||||
prompt_tasks=None,
|
||||
streaming=not shared.args.cpp_runner,
|
||||
output_sequence_lengths=True,
|
||||
return_dict=True,
|
||||
medusa_choices=None
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cumulative_reply = ''
|
||||
starting_from = batch_input_ids[0].shape[-1]
|
||||
|
||||
if shared.args.cpp_runner:
|
||||
sequence_length = generator['sequence_lengths'][0].item()
|
||||
output_ids = generator['output_ids'][0][0][:sequence_length].tolist()
|
||||
|
||||
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
|
||||
starting_from = sequence_length
|
||||
yield cumulative_reply
|
||||
else:
|
||||
for curr_outputs in generator:
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
sequence_length = curr_outputs['sequence_lengths'][0].item()
|
||||
output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()
|
||||
|
||||
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
|
||||
starting_from = sequence_length
|
||||
yield cumulative_reply
|
||||
|
||||
def generate(self, prompt, state):
|
||||
output = ''
|
||||
for output in self.generate_with_streaming(prompt, state):
|
||||
pass
|
||||
|
||||
return output
|
@ -54,7 +54,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||
yield ''
|
||||
return
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model']:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']:
|
||||
generate_func = generate_reply_custom
|
||||
else:
|
||||
generate_func = generate_reply_HF
|
||||
@ -132,7 +132,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||
if shared.tokenizer is None:
|
||||
raise ValueError('No tokenizer is loaded')
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model']:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']:
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
if shared.model.__class__.__name__ not in ['Exllamav2Model']:
|
||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||
@ -158,7 +158,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||
if truncation_length is not None:
|
||||
input_ids = input_ids[:, -truncation_length:]
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model'] or shared.args.cpu:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
|
||||
return input_ids
|
||||
elif shared.args.deepspeed:
|
||||
import deepspeed
|
||||
|
@ -106,6 +106,7 @@ def list_model_elements():
|
||||
'streaming_llm',
|
||||
'attention_sink_size',
|
||||
'hqq_backend',
|
||||
'cpp_runner',
|
||||
]
|
||||
if is_torch_xpu_available():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
|
@ -139,6 +139,7 @@ def create_ui():
|
||||
shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
|
||||
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
|
||||
shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
|
||||
shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
|
||||
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
|
||||
with gr.Blocks():
|
||||
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
|
||||
@ -149,6 +150,7 @@ def create_ui():
|
||||
shared.gradio['disable_exllamav2'] = gr.Checkbox(label="disable_exllamav2", value=shared.args.disable_exllamav2, info='Disable ExLlamav2 kernel for GPTQ models.')
|
||||
shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
|
||||
shared.gradio['llamacpp_HF_info'] = gr.Markdown("llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to place your GGUF in a subfolder of models/ with the necessary tokenizer files.\n\nYou can use the \"llamacpp_HF creator\" menu to do that automatically.")
|
||||
shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `max_seq_len` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
|
Loading…
Reference in New Issue
Block a user