import sys

import torch
from peft import PeftModel
import transformers
import gradio as gr
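
# Fail early if the installed transformers package does not yet include the LLaMA classes.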
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"

from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

LOAD_8BIT = False
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"
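
# Choose the compute device: CUDA if available, then Apple MPS, otherwise CPU.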
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except:
    pass
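
# Load the base LLaMA weights, then attach the Alpaca LoRA adapter on top.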
if device == "cuda":
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        load_in_8bit=LOAD_8BIT,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        torch_dtype=torch.float16,
    )
elif device == "mps":
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
else:
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        device_map={"": device},
    )
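

# Build the Alpaca-style instruction/input/response prompt fed to the model.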
def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""


if not LOAD_8BIT:
    model.half()  # seems to fix bugs for some users.

model.eval()

# torch.compile (PyTorch 2.x) speeds up inference; the check skips Windows, where it is not supported.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)
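

# Generate a response for a single instruction (and optional input) with the given decoding settings.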
def evaluate(
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    **kwargs,
):
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    # Run generation without tracking gradients (inference only).
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    # Return only the text that follows the response marker in the prompt.
    return output.split("### Response:")[1].strip()
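

# Expose evaluate() through a simple Gradio web UI.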
gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Instruction", placeholder="Tell me about alpacas."
        ),
        gr.components.Textbox(lines=2, label="Input", placeholder="none"),
        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
        gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
        gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
        gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
        gr.components.Slider(
            minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
        ),
    ],
    outputs=[
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🦙🌲 Alpaca-LoRA",
    description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
).launch()
# Old testing code follows.
"""
if __name__ == "__main__":
    # testing code for readme
    for instruction in [
        "Tell me about alpacas.",
        "Tell me about the president of Mexico in 2019.",
        "Tell me about the king of France in 2019.",
        "List all Canadian provinces in alphabetical order.",
        "Write a Python program that prints the first 10 Fibonacci numbers.",
        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
        "Tell me five words that rhyme with 'shock'.",
        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
        "Count up from 1 to 500.",
    ]:
        print("Instruction:", instruction)
        print("Response:", evaluate(instruction))
        print()
"""