Fix deepspeed (oops)

This commit is contained in:
oobabooga 2023-02-02 10:39:37 -03:00
parent 90f1067598
commit f38c9bf428
2 changed files with 3 additions and 3 deletions

View File

@ -38,7 +38,7 @@ parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--bf16', action='store_true', help='DeepSpeed: Instantiate the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--local-rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
@ -80,7 +80,7 @@ if args.settings is not None and Path(args.settings).exists():
if args.deepspeed:
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
from modules.deepseed_config import generate_ds_config
from modules.deepspeed_parameters import generate_ds_config
# Distributed setup
if args.local_rank is not None:
@ -90,7 +90,7 @@ if args.deepspeed:
world_size = int(os.getenv("WORLD_SIZE", "1"))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()
ds_config = generate_ds_config(args.bf16, 1 * world_size, nvme_offload_dir)
ds_config = generate_ds_config(args.bf16, 1 * world_size, args.nvme_offload_dir)
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
def load_model(model_name):