Add arg for bfloat16

This commit is contained in:
81300 2023-02-01 20:22:07 +02:00
parent c515282f5c
commit a6f4760772
No known key found for this signature in database

View File

@ -37,6 +37,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.') parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='Directory to use for DeepSpeed ZeRO-3 NVME offloading.') parser.add_argument('--nvme-offload-dir', type=str, help='Directory to use for DeepSpeed ZeRO-3 NVME offloading.')
parser.add_argument('--bf16', action='store_true', help='Instantiate the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--local_rank', type=int, default=0, help='Optional argument for DeepSpeed distributed setups.') parser.add_argument('--local_rank', type=int, default=0, help='Optional argument for DeepSpeed distributed setups.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
@ -92,14 +93,20 @@ if args.deepspeed:
# DeepSpeed configration # DeepSpeed configration
# https://huggingface.co/docs/transformers/main_classes/deepspeed # https://huggingface.co/docs/transformers/main_classes/deepspeed
if args.bf16:
ds_fp16 = False
ds_bf16 = True
else:
ds_fp16 = True
ds_bf16 = False
train_batch_size = 1 * world_size train_batch_size = 1 * world_size
if args.nvme_offload_dir: if args.nvme_offload_dir:
ds_config = { ds_config = {
"fp16": { "fp16": {
"enabled": True, "enabled": ds_fp16,
}, },
"bf16": { "bf16": {
"enabled": False, "enabled": ds_bf16,
}, },
"zero_optimization": { "zero_optimization": {
"stage": 3, "stage": 3,
@ -135,10 +142,10 @@ if args.deepspeed:
else: else:
ds_config = { ds_config = {
"fp16": { "fp16": {
"enabled": True, "enabled": ds_fp16,
}, },
"bf16": { "bf16": {
"enabled": False, "enabled": ds_bf16,
}, },
"zero_optimization": { "zero_optimization": {
"stage": 3, "stage": 3,
@ -178,7 +185,10 @@ def load_model(model_name):
# DeepSpeed ZeRO-3 # DeepSpeed ZeRO-3
elif args.deepspeed: elif args.deepspeed:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}")) if args.bf16:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16)
else:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16)
model = deepspeed.initialize(model=model, model = deepspeed.initialize(model=model,
config_params=ds_config, config_params=ds_config,
model_parameters=None, model_parameters=None,