'''
Converts a transformers model to .pt, which is faster to load.

Run with: python convert-to-torch.py /path/to/model/

A trailing / on the model path is accepted but no longer required.

Output will be written to torch-dumps/name-of-the-model.pt
'''

import os
import sys

import torch


def get_output_path(model_dir):
    """Return 'torch-dumps/<model-name>.pt' for *model_dir*.

    The model name is the last path component; works with or without a
    trailing slash (the original required the trailing slash and silently
    picked the wrong component without it).
    """
    model_name = os.path.basename(os.path.normpath(model_dir))
    return f"torch-dumps/{model_name}.pt"


def main():
    # transformers is imported lazily so that merely importing this file
    # does not require the heavy dependency; it is only needed to convert.
    from transformers import AutoModelForCausalLM, OPTForCausalLM, T5ForConditionalGeneration

    if len(sys.argv) < 2:
        sys.exit("Usage: python convert-to-torch.py /path/to/model/")

    path = sys.argv[1]
    output_path = get_output_path(path)
    print(output_path)

    # torch.save would fail with FileNotFoundError if the directory is missing.
    os.makedirs("torch-dumps", exist_ok=True)

    # NOTE(review): this matches any path ending in 'pt' (e.g. '.../opt'),
    # not only '.pt' files — preserved as-is; confirm the intended check.
    if path.endswith('pt'):
        # Reload an existing checkpoint with automatic device placement.
        model = OPTForCausalLM.from_pretrained(path, device_map="auto")
    elif 'galactica' in path.lower():
        # Galactica is an OPT-architecture model; load in fp16 to halve memory.
        model = OPTForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16)
    elif 'flan-t5' in path.lower():
        model = T5ForConditionalGeneration.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16)
    else:
        print("Loading the model")
        model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16)
        print("Model loaded")

    # Pickle the whole model object; loading the .pt later is faster than
    # re-running from_pretrained.
    torch.save(model, output_path)


if __name__ == "__main__":
    main()