Remove T5 support

This commit is contained in:
oobabooga 2023-01-10 23:41:35 -03:00
parent b2a2ddcb15
commit 18ae08ef91

View File

@ -7,7 +7,7 @@ python convert-to-torch.py models/opt-1.3b
The output will be written to torch-dumps/name-of-the-model.pt The output will be written to torch-dumps/name-of-the-model.pt
''' '''
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration from transformers import AutoModelForCausalLM
import torch import torch
from sys import argv from sys import argv
from pathlib import Path from pathlib import Path
@ -16,9 +16,6 @@ path = Path(argv[1])
model_name = path.name model_name = path.name
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
if model_name in ['flan-t5', 't5-large']:
model = T5ForConditionalGeneration.from_pretrained(path).cuda()
else:
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
print("Model loaded.") print("Model loaded.")