text-generation-webui/convert-to-torch.py
2022-12-21 13:28:19 -03:00

39 lines
1.6 KiB
Python

'''
Converts a transformers model to .pt, which is faster to load.
Run with python convert-to-torch.py /path/to/model/
Make sure to write /path/to/model/ with a trailing / and not
/path/to/model
Output will be written to torch-dumps/name-of-the-model.pt
'''
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, AutoTokenizer, set_seed
from transformers import GPT2Tokenizer, GPT2Model, T5Tokenizer, T5ForConditionalGeneration
import torch
import sys
from sys import argv
import time
import glob
import psutil
from pathlib import Path

# Directory (relative to the current working directory) that receives the dumps.
OUTPUT_DIR = "torch-dumps"


def get_output_path(model_path):
    '''
    Return the "torch-dumps/<model-name>.pt" path for model_path.

    The model name is the last component of model_path. A trailing slash is
    accepted but no longer required: "models/opt-1.3b/" and "models/opt-1.3b"
    both map to "torch-dumps/opt-1.3b.pt".

    Raises ValueError when no name can be derived (e.g. "" or "/").
    '''
    name = Path(model_path).name
    if not name:
        raise ValueError(f"Cannot derive a model name from {model_path!r}")
    return f"{OUTPUT_DIR}/{name}.pt"


def load_model(model_path):
    '''
    Load the transformers model stored at model_path.

    The model class and loading options are picked from the path string:
    paths ending in "pt" and paths containing "galactica" use OPTForCausalLM,
    paths containing "flan-t5" use T5ForConditionalGeneration, and anything
    else falls back to AutoModelForCausalLM. All but the first case load the
    weights as float16 with low_cpu_mem_usage=True.
    '''
    lowered = model_path.lower()

    # NOTE(review): this check runs on the raw argument, so it can only match
    # when the path is given without a trailing slash -- confirm the intent.
    if model_path.endswith('pt'):
        return OPTForCausalLM.from_pretrained(model_path, device_map="auto")

    # An 8-bit variant (load_in_8bit=True) was tried for the galactica and
    # fallback cases and is intentionally left disabled.
    if 'galactica' in lowered:
        return OPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16)

    if 'flan-t5' in lowered:
        return T5ForConditionalGeneration.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16)

    print("Loading the model")
    model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16)
    print("Model loaded")
    return model


def main(args=None):
    '''
    Convert the model named on the command line to a .pt dump.

    args defaults to sys.argv; args[1] must be the path to the model. The
    script exits with a usage message when that argument is missing.
    '''
    args = argv if args is None else args
    if len(args) < 2:
        sys.exit("Usage: python convert-to-torch.py /path/to/model/")

    model_path = args[1]
    output_path = get_output_path(model_path)
    print(output_path)

    # Create the destination up front so a missing directory is not discovered
    # only after the (slow) model load has finished.
    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    model = load_model(model_path)
    torch.save(model, output_path)


if __name__ == "__main__":
    main()