Seq2Seq support (including FLAN-T5) (#1535)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Vincent Brouwers 2023-04-26 03:39:04 +02:00 committed by GitHub
parent 95aa43b9c2
commit 92cdb4f22b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 24 deletions

View File

@ -11,7 +11,8 @@ import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
AutoModelForSeq2SeqLM, AutoTokenizer,
BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared
from modules import llama_attn_hijack
@ -54,6 +55,11 @@ def find_model_type(model_name):
return 'llava'
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
return 'gpt4chan'
else:
config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}")
# Not a "catch all", but fairly accurate
if config.to_dict().get("is_encoder_decoder", False):
return 'HF_seq2seq'
else:
return 'HF_generic'
@ -66,6 +72,9 @@ def load_model(model_name):
if shared.model_type == 'chatglm':
LoaderClass = AutoModel
trust_remote_code = shared.args.trust_remote_code
elif shared.model_type == 'HF_seq2seq':
LoaderClass = AutoModelForSeq2SeqLM
trust_remote_code = False
else:
LoaderClass = AutoModelForCausalLM
trust_remote_code = False

View File

@ -58,12 +58,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
def decode(output_ids, skip_special_tokens=True):
if skip_special_tokens:
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
return reply
return shared.tokenizer.decode(output_ids, skip_special_tokens)
def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
if shared.model_type == 'HF_seq2seq':
reply = decode(output_ids, state['skip_special_tokens'])
if not shared.is_chat():
reply = apply_extensions('output', reply)
else:
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
new_tokens = len(output_ids) - len(input_ids[0])
reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
return reply
def generate_softprompt_input_tensors(input_ids):
@ -262,11 +271,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
yield formatted_outputs(reply, shared.model_name)
# Stream the reply 1 token at a time.
@ -282,7 +287,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
if not shared.is_chat():
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
yield formatted_outputs(original_question, shared.model_name)
with generate_with_streaming(**generate_params) as generator:
@ -290,11 +295,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
if output[-1] in eos_token_ids:
break
@ -310,11 +311,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(original_input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
@ -334,6 +331,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
finally:
t1 = time.time()
original_tokens = len(original_input_ids[0])
new_tokens = len(output) - original_tokens
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return