diff --git a/server.py b/server.py index fbadacc5..de93dccc 100644 --- a/server.py +++ b/server.py @@ -83,10 +83,7 @@ if args.deepspeed: from modules.deepspeed_parameters import generate_ds_config # Distributed setup - if args.local_rank is not None: - local_rank = args.local_rank - else: - local_rank = int(os.getenv("LOCAL_RANK", "0")) + local_rank = args.local_rank if args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) torch.cuda.set_device(local_rank) deepspeed.init_distributed() @@ -109,15 +106,8 @@ def load_model(model_name): # DeepSpeed ZeRO-3 elif args.deepspeed: - if args.bf16: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16) - else: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16) - model = deepspeed.initialize(model=model, - config_params=ds_config, - model_parameters=None, - optimizer=None, - lr_scheduler=None)[0] + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16) + model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model.module.eval() # Inference print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") @@ -183,7 +173,11 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): else: torch.cuda.empty_cache() input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda() - return input_ids + + if not args.deepspeed: + return input_ids + else: + return input_ids.to(device=local_rank) def decode(output_ids): reply = tokenizer.decode(output_ids, skip_special_tokens=True) @@ -226,10 +220,8 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok cuda = "" if args.cpu else ".cuda()" n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1] - if args.deepspeed: - input_ids = encode(question, tokens).to(device=local_rank) - else: - input_ids = encode(question, tokens) + input_ids = encode(question, tokens) + if stopping_string is not None: # The stopping_criteria code below was copied from # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py @@ -246,11 +238,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok # Generate the entire reply at once if args.no_stream: t0 = time.time() - if args.deepspeed: - with torch.no_grad(): + with torch.no_grad(): + if not args.deepspeed: + output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}") + else: output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})") - else: - output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}") reply = decode(output[0]) t1 = time.time() print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)") @@ -263,11 +255,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok yield formatted_outputs(original_question, model_name) preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8') for i in tqdm(range(tokens//8+1)): - if args.deepspeed: - with torch.no_grad(): + with torch.no_grad(): + if not args.deepspeed: + output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}") + else: output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})") - else: - output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}") reply = decode(output[0]) if not (args.chat or args.cai_chat): reply = original_question + apply_extensions(reply[len(question):], "output")