mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Deprecate torch dumps, move to safetensors (they load even faster)
This commit is contained in:
parent
14ffa0b418
commit
e195377050
@ -112,14 +112,6 @@ After downloading the model, follow these steps:
|
|||||||
python download-model.py EleutherAI/gpt-j-6B --text-only
|
python download-model.py EleutherAI/gpt-j-6B --text-only
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Converting to pytorch (optional)
|
|
||||||
|
|
||||||
The script `convert-to-torch.py` allows you to convert models to .pt format, which can be a lot faster to load to the GPU:
|
|
||||||
|
|
||||||
python convert-to-torch.py models/model-name
|
|
||||||
|
|
||||||
The output model will be saved to `torch-dumps/model-name.pt`. When you load a new model, the web UI first looks for this .pt file; if it is not found, it loads the model as usual from `models/model-name`.
|
|
||||||
|
|
||||||
## Starting the web UI
|
## Starting the web UI
|
||||||
|
|
||||||
conda activate textgen
|
conda activate textgen
|
||||||
|
40
convert-to-safetensors.py
Normal file
40
convert-to-safetensors.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
'''
|
||||||
|
|
||||||
|
Converts a transformers model to safetensors format and shards it.
|
||||||
|
|
||||||
|
This makes it faster to load (because of safetensors) and lowers its RAM usage
|
||||||
|
while loading (because of sharding).
|
||||||
|
|
||||||
|
Based on the original script by 81300:
|
||||||
|
|
||||||
|
https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from sys import argv
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||||
|
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||||
|
parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
|
||||||
|
parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
|
||||||
|
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
path = Path(args.MODEL)
|
||||||
|
model_name = path.name
|
||||||
|
|
||||||
|
print(f"Loading {model_name}...")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(path)
|
||||||
|
|
||||||
|
out_folder = args.output or Path(f"models/{model_name}_safetensors")
|
||||||
|
print(f"Saving the converted model to {out_folder} with a maximum shard size of {args.max_shard_size}...")
|
||||||
|
model.save_pretrained(out_folder, max_shard_size=args.max_shard_size, safe_serialization=True)
|
||||||
|
tokenizer.save_pretrained(out_folder)
|
@ -1,22 +0,0 @@
|
|||||||
'''
|
|
||||||
Converts a transformers model to .pt, which is faster to load.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
python convert-to-torch.py models/opt-1.3b
|
|
||||||
|
|
||||||
The output will be written to torch-dumps/name-of-the-model.pt
|
|
||||||
'''
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from sys import argv
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
path = Path(argv[1])
|
|
||||||
model_name = path.name
|
|
||||||
|
|
||||||
print(f"Loading {model_name}...")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
|
|
||||||
print(f"Model loaded.\nSaving to torch-dumps/{model_name}.pt")
|
|
||||||
torch.save(model, Path(f"torch-dumps/{model_name}.pt"))
|
|
@ -108,10 +108,7 @@ def load_model(model_name):
|
|||||||
|
|
||||||
# Default settings
|
# Default settings
|
||||||
if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed):
|
if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed):
|
||||||
if Path(f"torch-dumps/{model_name}.pt").exists():
|
if model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
|
||||||
print("Loading in .pt format...")
|
|
||||||
model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
|
|
||||||
elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
|
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16).cuda()
|
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16).cuda()
|
||||||
@ -425,7 +422,7 @@ def update_extensions_parameters(*kwargs):
|
|||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
def get_available_models():
|
def get_available_models():
|
||||||
return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
|
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith('.txt')], key=lambda x: x.lower())
|
||||||
|
|
||||||
def get_available_presets():
|
def get_available_presets():
|
||||||
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
||||||
|
Loading…
Reference in New Issue
Block a user