mirror of https://github.com/tloen/alpaca-lora.git (synced 2024-10-01 01:05:56 -04:00)
fix fp16 inference
commit e04897baae
parent 052da42cbb
generate.py (14 changed lines)
@@ -1,3 +1,4 @@
+import sys
 import torch
 from peft import PeftModel
 import transformers
@@ -10,6 +11,7 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 
+LOAD_8BIT = False
 BASE_MODEL = "decapoda-research/llama-7b-hf"
 LORA_WEIGHTS = "tloen/alpaca-lora-7b"
 
@@ -27,11 +29,15 @@ except:
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
-        load_in_8bit=True,
+        load_in_8bit=LOAD_8BIT,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
+    model = PeftModel.from_pretrained(
+        model,
+        LORA_WEIGHTS,
+        torch_dtype=torch.float16,
+    )
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
@@ -74,9 +80,11 @@ def generate_prompt(instruction, input=None):
 
 ### Response:"""
 
+if not LOAD_8BIT:
+    model.half()  # seems to fix bugs for some users.
 
 model.eval()
-if torch.__version__ >= "2":
+if torch.__version__ >= "2" and sys.platform != "win32":
     model = torch.compile(model)
 
 
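For context, the hunks above amount to the following fp16 loading path in the CUDA branch of generate.py. This is a minimal sketch assembled from the diff, not the full file: device detection, the mps/cpu branches, and the prompt/generation code are omitted, and the model names and the LOAD_8BIT flag are the ones already used in the diff.

# Sketch of the fp16 loading path after this change (CUDA branch only).
# BASE_MODEL, LORA_WEIGHTS, and LOAD_8BIT come from the diff above; the
# rest of generate.py (device detection, prompting) is left out here.
import sys

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

LOAD_8BIT = False  # False -> load plain fp16 weights instead of 8-bit quantization
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=LOAD_8BIT,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    LORA_WEIGHTS,
    torch_dtype=torch.float16,
)

if not LOAD_8BIT:
    model.half()  # reported to fix fp16 generation issues for some users

model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)  # torch.compile is skipped on Windows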
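A hypothetical smoke test, not part of this commit, for checking that fp16 generation behaves after the change; the prompt text and generation settings below are made up for illustration.

# Hypothetical check, not from the repository: run one short generation in
# fp16 to confirm the half-precision weights produce coherent output.
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nName three primary colors.\n\n### Response:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))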