mirror of
https://github.com/tloen/alpaca-lora.git
synced 2024-10-01 01:05:56 -04:00
fix fp16 inference
This commit is contained in:
parent
052da42cbb
commit
e04897baae
14
generate.py
14
generate.py
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
import torch
|
import torch
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
import transformers
|
import transformers
|
||||||
@ -10,6 +11,7 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
|
|||||||
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
|
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
|
||||||
|
|
||||||
|
LOAD_8BIT = False
|
||||||
BASE_MODEL = "decapoda-research/llama-7b-hf"
|
BASE_MODEL = "decapoda-research/llama-7b-hf"
|
||||||
LORA_WEIGHTS = "tloen/alpaca-lora-7b"
|
LORA_WEIGHTS = "tloen/alpaca-lora-7b"
|
||||||
|
|
||||||
@ -27,11 +29,15 @@ except:
|
|||||||
if device == "cuda":
|
if device == "cuda":
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
BASE_MODEL,
|
BASE_MODEL,
|
||||||
load_in_8bit=True,
|
load_in_8bit=LOAD_8BIT,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
)
|
)
|
||||||
model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
|
model = PeftModel.from_pretrained(
|
||||||
|
model,
|
||||||
|
LORA_WEIGHTS,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
elif device == "mps":
|
elif device == "mps":
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
BASE_MODEL,
|
BASE_MODEL,
|
||||||
@ -74,9 +80,11 @@ def generate_prompt(instruction, input=None):
|
|||||||
|
|
||||||
### Response:"""
|
### Response:"""
|
||||||
|
|
||||||
|
if not LOAD_8BIT:
|
||||||
|
model.half() # seems to fix bugs for some users.
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
if torch.__version__ >= "2":
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user