fix fp16 inference

Eric Wang 2023-03-21 14:31:25 -07:00
parent 052da42cbb
commit e04897baae


@@ -1,3 +1,4 @@
+import sys
 import torch
 from peft import PeftModel
 import transformers
@@ -10,6 +11,7 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+LOAD_8BIT = False
 BASE_MODEL = "decapoda-research/llama-7b-hf"
 LORA_WEIGHTS = "tloen/alpaca-lora-7b"
@@ -27,11 +29,15 @@ except:
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
-        load_in_8bit=True,
+        load_in_8bit=LOAD_8BIT,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
+    model = PeftModel.from_pretrained(
+        model,
+        LORA_WEIGHTS,
+        torch_dtype=torch.float16,
+    )
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
@@ -74,9 +80,11 @@ def generate_prompt(instruction, input=None):
 ### Response:"""
+if not LOAD_8BIT:
+    model.half()  # seems to fix bugs for some users.
 model.eval()
-if torch.__version__ >= "2":
+if torch.__version__ >= "2" and sys.platform != "win32":
     model = torch.compile(model)
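
For reference, a minimal sketch of the CUDA inference path as it stands after this change. Model and adapter identifiers are the ones shown in the diff; the standalone-script framing and comments are illustrative, not part of the commit.

import sys

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

LOAD_8BIT = False  # False -> the full fp16 path fixed by this commit
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=LOAD_8BIT,      # int8 only when explicitly requested
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    LORA_WEIGHTS,
    torch_dtype=torch.float16,   # keep the LoRA weights in fp16 as well
)

if not LOAD_8BIT:
    model.half()  # seems to fix bugs for some users.

model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)  # torch.compile is not supported on Windows

With LOAD_8BIT = True the base model is instead loaded as int8 via bitsandbytes and the extra model.half() call is skipped, which is why the flag gates both the from_pretrained argument and the half-precision cast.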