mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Seq2Seq support (including FLAN-T5) (#1535)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
95aa43b9c2
commit
92cdb4f22b
@ -11,7 +11,8 @@ import torch
|
||||
import transformers
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
||||
AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
|
||||
AutoModelForSeq2SeqLM, AutoTokenizer,
|
||||
BitsAndBytesConfig, LlamaTokenizer)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import llama_attn_hijack
|
||||
@ -55,7 +56,12 @@ def find_model_type(model_name):
|
||||
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
return 'HF_generic'
|
||||
config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}")
|
||||
# Not a "catch all", but fairly accurate
|
||||
if config.to_dict().get("is_encoder_decoder", False):
|
||||
return 'HF_seq2seq'
|
||||
else:
|
||||
return 'HF_generic'
|
||||
|
||||
|
||||
def load_model(model_name):
|
||||
@ -66,6 +72,9 @@ def load_model(model_name):
|
||||
if shared.model_type == 'chatglm':
|
||||
LoaderClass = AutoModel
|
||||
trust_remote_code = shared.args.trust_remote_code
|
||||
elif shared.model_type == 'HF_seq2seq':
|
||||
LoaderClass = AutoModelForSeq2SeqLM
|
||||
trust_remote_code = False
|
||||
else:
|
||||
LoaderClass = AutoModelForCausalLM
|
||||
trust_remote_code = False
|
||||
|
@ -58,12 +58,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||
|
||||
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
if skip_special_tokens:
|
||||
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
return reply
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
||||
def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
|
||||
if shared.model_type == 'HF_seq2seq':
|
||||
reply = decode(output_ids, state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = apply_extensions('output', reply)
|
||||
else:
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
new_tokens = len(output_ids) - len(input_ids[0])
|
||||
reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
@ -262,11 +271,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the reply 1 token at a time.
|
||||
@ -282,7 +287,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
def generate_with_streaming(**kwargs):
|
||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||
|
||||
if not shared.is_chat():
|
||||
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
@ -290,11 +295,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
||||
@ -310,11 +311,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
||||
@ -334,6 +331,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output) - original_tokens
|
||||
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user