text-generation-webui/convert-to-torch.py

'''
Converts a Hugging Face transformers causal language model to a single .pt
file, which is faster to load than the original checkpoint folder.

The model is loaded in float16 and moved to the GPU before saving, so a
CUDA-capable device is required.

Example:
python convert-to-torch.py models/opt-1.3b

The output will be written to torch-dumps/name-of-the-model.pt
'''
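
import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM

if len(sys.argv) != 2:
    sys.exit("Usage: python convert-to-torch.py <path-to-model-folder>")

path = Path(sys.argv[1])
model_name = path.name
output_path = Path(f"torch-dumps/{model_name}.pt")

print(f"Loading {model_name}...")
# Load the weights in float16 to halve memory use, then move the model to the GPU.
# low_cpu_mem_usage keeps peak CPU RAM close to a single copy of the weights.
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

print(f"Model loaded.\nSaving to {output_path}")
# torch.save does not create missing parent directories, so make sure the folder exists.
output_path.parent.mkdir(parents=True, exist_ok=True)
# Pickle the entire nn.Module (not just its state_dict).
torch.save(model, output_path)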
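
# Loading the dump: a minimal sketch, not part of the original script.
# load_torch_dump is a hypothetical helper name. torch.save above pickles the
# whole nn.Module rather than a state_dict, so torch.load has to be allowed to
# unpickle arbitrary Python objects: pass weights_only=False (the argument
# exists since PyTorch 1.13, and PyTorch 2.6+ defaults it to True, which
# rejects a pickled model). Unpickling can execute code, so only load dumps
# you created yourself. The imports are repeated so this snippet stands alone.
import torch
from pathlib import Path


def load_torch_dump(model_name, dumps_dir="torch-dumps"):
    """Return the model previously saved as <dumps_dir>/<model_name>.pt."""
    # With map_location left at its default, tensors go back to the device they
    # were saved from (the GPU, for dumps written by this script), and the
    # transformers classes must be importable to rebuild the model object.
    return torch.load(Path(dumps_dir) / f"{model_name}.pt", weights_only=False)


# Example usage (assumes the conversion above was run for models/opt-1.3b):
# model = load_torch_dump("opt-1.3b")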