fix fp16 inference

Eric Wang 2023-03-21 14:31:25 -07:00
parent 052da42cbb
commit e04897baae


@@ -1,3 +1,4 @@
+import sys
 import torch
 from peft import PeftModel
 import transformers
@@ -10,6 +11,7 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+LOAD_8BIT = False
 BASE_MODEL = "decapoda-research/llama-7b-hf"
 LORA_WEIGHTS = "tloen/alpaca-lora-7b"
@@ -27,11 +29,15 @@ except:
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
-        load_in_8bit=True,
+        load_in_8bit=LOAD_8BIT,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
+    model = PeftModel.from_pretrained(
+        model,
+        LORA_WEIGHTS,
+        torch_dtype=torch.float16,
+    )
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
@@ -74,9 +80,11 @@ def generate_prompt(instruction, input=None):
 ### Response:"""
+if not LOAD_8BIT:
+    model.half()  # seems to fix bugs for some users.
 model.eval()
-if torch.__version__ >= "2":
+if torch.__version__ >= "2" and sys.platform != "win32":
     model = torch.compile(model)
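
Putting the changed pieces together, below is a minimal sketch of the fp16 load path this commit produces, assuming a CUDA device, the model and adapter names shown in the diff, and the transformers/peft versions of the time. The constants and calls mirror the changed lines above; the rest of the script (prompt construction, generation loop) is omitted, so this is illustrative rather than the complete file.

import sys

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

LOAD_8BIT = False  # False -> load plain fp16 weights instead of 8-bit quantized ones
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Load the base model; torch_dtype=torch.float16 keeps the weights in half precision.
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=LOAD_8BIT,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Apply the LoRA adapter in the same dtype.
model = PeftModel.from_pretrained(
    model,
    LORA_WEIGHTS,
    torch_dtype=torch.float16,
)

if not LOAD_8BIT:
    model.half()  # cast any remaining fp32 params to fp16; "seems to fix bugs for some users"

model.eval()

# torch.compile is not supported on Windows at this point, hence the platform check.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)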