Seq2Seq support (including FLAN-T5) (#1535)
Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
commit 92cdb4f22b (parent 95aa43b9c2)
modules/models.py
@@ -11,7 +11,8 @@ import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
-                          AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
+                          AutoModelForSeq2SeqLM, AutoTokenizer,
+                          BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
 from modules import llama_attn_hijack
@@ -55,7 +56,12 @@ def find_model_type(model_name):
     elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
         return 'gpt4chan'
     else:
-        return 'HF_generic'
+        config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}")
+        # Not a "catch all", but fairly accurate
+        if config.to_dict().get("is_encoder_decoder", False):
+            return 'HF_seq2seq'
+        else:
+            return 'HF_generic'
 
 
 def load_model(model_name):
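Note on the new detection branch: Hugging Face model configs carry an `is_encoder_decoder` flag that is `True` for encoder-decoder architectures such as T5/FLAN-T5 and `False` (or absent) for decoder-only models, which is what `find_model_type()` now keys on. A minimal standalone sketch of the same check (the checkpoint names below are illustrative, not part of this commit):

from transformers import AutoConfig

# Decoder-only vs. encoder-decoder, distinguished only by the config flag.
for name in ["gpt2", "google/flan-t5-small"]:  # example checkpoints
    config = AutoConfig.from_pretrained(name)
    model_type = 'HF_seq2seq' if config.to_dict().get("is_encoder_decoder", False) else 'HF_generic'
    print(name, "->", model_type)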
@@ -66,6 +72,9 @@ def load_model(model_name):
     if shared.model_type == 'chatglm':
         LoaderClass = AutoModel
         trust_remote_code = shared.args.trust_remote_code
+    elif shared.model_type == 'HF_seq2seq':
+        LoaderClass = AutoModelForSeq2SeqLM
+        trust_remote_code = False
     else:
         LoaderClass = AutoModelForCausalLM
         trust_remote_code = False
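For reference, this is roughly what loading an encoder-decoder checkpoint through the new `LoaderClass` path amounts to outside the webui; `AutoModelForCausalLM` is not registered for T5-style configs, hence the per-type loader selection. The checkpoint name and prompt below are examples only:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "google/flan-t5-small"  # illustrative FLAN-T5 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("Translate to German: How are you?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)

# Seq2seq generate() returns only decoder-produced tokens, so the whole
# sequence is the reply.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))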
modules/text_generation.py
@@ -58,12 +58,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
 
 
 def decode(output_ids, skip_special_tokens=True):
-    if skip_special_tokens:
-        reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
-        reply = reply.replace(r'<|endoftext|>', '')
-        return reply
+    return shared.tokenizer.decode(output_ids, skip_special_tokens)
+
+
+def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
+    if shared.model_type == 'HF_seq2seq':
+        reply = decode(output_ids, state['skip_special_tokens'])
+        if not shared.is_chat():
+            reply = apply_extensions('output', reply)
     else:
-        return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
+        new_tokens = len(output_ids) - len(input_ids[0])
+        reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
+        if not shared.is_chat():
+            reply = original_question + apply_extensions('output', reply)
+
+    return reply
 
 
 def generate_softprompt_input_tensors(input_ids):
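The branch in `get_reply_from_output_ids()` exists because the two model families shape their `generate()` output differently: a causal (decoder-only) model returns the prompt tokens followed by the continuation, so the reply has to be sliced off the end, while an encoder-decoder model returns only the tokens its decoder produced. A standalone sketch of that difference (example checkpoints, not webui code):

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

prompt = "The capital of France is"

# Decoder-only: output echoes the prompt, new tokens are appended after it.
tok = AutoTokenizer.from_pretrained("gpt2")
causal = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tok(prompt, return_tensors="pt").input_ids
output_ids = causal.generate(input_ids, max_new_tokens=8, pad_token_id=tok.eos_token_id)[0]
new_tokens = len(output_ids) - input_ids.shape[1]
print("causal reply:", tok.decode(output_ids[-new_tokens:], skip_special_tokens=True))

# Encoder-decoder: output contains only decoder-generated tokens.
tok2 = AutoTokenizer.from_pretrained("google/flan-t5-small")
seq2seq = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
output_ids2 = seq2seq.generate(**tok2(prompt, return_tensors="pt"), max_new_tokens=8)[0]
print("seq2seq reply:", tok2.decode(output_ids2, skip_special_tokens=True))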
@@ -262,11 +271,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
-            new_tokens = len(output) - len(input_ids[0])
-            reply = decode(output[-new_tokens:], state['skip_special_tokens'])
-            if not shared.is_chat():
-                reply = original_question + apply_extensions('output', reply)
-
+            reply = get_reply_from_output_ids(output, input_ids, original_question, state)
             yield formatted_outputs(reply, shared.model_name)
 
         # Stream the reply 1 token at a time.
@@ -282,7 +287,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             def generate_with_streaming(**kwargs):
                 return Iteratorize(generate_with_callback, kwargs, callback=None)
 
-            if not shared.is_chat():
+            if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
                 yield formatted_outputs(original_question, shared.model_name)
 
             with generate_with_streaming(**generate_params) as generator:
@@ -290,11 +295,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
-                    new_tokens = len(output) - len(input_ids[0])
-                    reply = decode(output[-new_tokens:], state['skip_special_tokens'])
-                    if not shared.is_chat():
-                        reply = original_question + apply_extensions('output', reply)
-
+                    reply = get_reply_from_output_ids(output, input_ids, original_question, state)
                     if output[-1] in eos_token_ids:
                         break
 
@@ -310,11 +311,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
                 if shared.soft_prompt:
                     output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
-                new_tokens = len(output) - len(original_input_ids[0])
-                reply = decode(output[-new_tokens:], state['skip_special_tokens'])
-                if not shared.is_chat():
-                    reply = original_question + apply_extensions('output', reply)
-
+                reply = get_reply_from_output_ids(output, input_ids, original_question, state)
                 if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                     break
 
@@ -334,6 +331,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     finally:
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
-        new_tokens = len(output) - original_tokens
+        new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
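The stats line changes for the same reason: with a causal model, `output` still contains the prompt, so its length is subtracted to count generated tokens; with a seq2seq model, `output` is already only generated tokens. A small sketch of that arithmetic (the helper name is hypothetical, not part of the commit):

def count_new_tokens(output_length, original_tokens, is_seq2seq):
    # Mirrors the updated line: subtract the prompt length only for causal models,
    # whose generate() output includes the prompt tokens.
    return output_length - (0 if is_seq2seq else original_tokens)

assert count_new_tokens(58, 50, is_seq2seq=False) == 8  # causal: prompt is echoed
assert count_new_tokens(8, 50, is_seq2seq=True) == 8    # seq2seq: reply tokens only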