import os
import sys
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
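# torch.backends.mps only exists on newer PyTorch builds, so the probe below is
# wrapped in try/except and older installs keep the CUDA/CPU choice made above.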
try:
    if torch.backends.mps.is_available():
        device = "mps"
except:  # noqa: E722
    pass
def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "tloen/alpaca-lora-7b",
    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
    server_name: str = "0.0.0.0",  # Allows to listen on all interfaces by providing '0.0.0.0'.
    share_gradio: bool = False,
):
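    """Load a LLaMA base model plus LoRA weights and serve a Gradio demo for instruction following."""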
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    prompter = Prompter(prompt_template)
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
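    # Load the base model and attach the LoRA adapter on top. CUDA uses fp16
    # (optionally 8-bit via bitsandbytes when --load_8bit is set), Apple MPS uses
    # fp16 on-device, and anything else falls back to a plain CPU load.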
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
        )
    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2
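    # The early decapoda-research LLaMA conversions shipped mismatched special-token
    # IDs, so pad (unk), bos, and eos are pinned explicitly here.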
    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
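    # torch.compile is a PyTorch 2.x feature; it is skipped on Windows, where
    # compile support has been limited.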
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
    def evaluate(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        stream_output=False,
        **kwargs,
    ):
        prompt = prompter.generate_prompt(instruction, input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        generate_params = {
            "input_ids": input_ids,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
            "output_scores": True,
            "max_new_tokens": max_new_tokens,
        }
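        # The same settings drive both paths below: generate_params feeds the
        # streaming helper, while the non-streaming call passes them explicitly.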
        if stream_output:
            # Stream the reply 1 token at a time.
            # This is based on the trick of using 'stopping_criteria' to create an iterator,
            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
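            # Stream (from utils.callbacks) is a StoppingCriteria that hands each new
            # token to a callback, and Iteratorize wraps the callback-style call so it
            # can be consumed as a generator in the loop below.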
            def generate_with_callback(callback=None, **kwargs):
                kwargs.setdefault(
                    "stopping_criteria", transformers.StoppingCriteriaList()
                )
                kwargs["stopping_criteria"].append(
                    Stream(callback_func=callback)
                )
                with torch.no_grad():
                    model.generate(**kwargs)
            def generate_with_streaming(**kwargs):
                return Iteratorize(
                    generate_with_callback, kwargs, callback=None
                )
            with generate_with_streaming(**generate_params) as generator:
                for output in generator:
                    # new_tokens = len(output) - len(input_ids[0])
                    decoded_output = tokenizer.decode(output)

                    if output[-1] in [tokenizer.eos_token_id]:
                        break

                    yield prompter.get_response(decoded_output)
            return  # early return for stream_output
        # Without streaming
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        yield prompter.get_response(output)
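        # Note that evaluate yields in both modes, so Gradio treats it as a
        # generator for streamed and single-shot output alike.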
    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=2,
                label="Instruction",
                placeholder="Tell me about alpacas.",
            ),
            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.1, label="Temperature"
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.75, label="Top p"
            ),
            gr.components.Slider(
                minimum=0, maximum=100, step=1, value=40, label="Top k"
            ),
            gr.components.Slider(
                minimum=1, maximum=4, step=1, value=4, label="Beams"
            ),
            gr.components.Slider(
                minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
            ),
            gr.components.Checkbox(label="Stream output"),
        ],
        outputs=[
            gr.inputs.Textbox(
                lines=5,
                label="Output",
            )
        ],
        title="🦙🌲 Alpaca-LoRA",
        description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",  # noqa: E501
    ).queue().launch(server_name=server_name, share=share_gradio)
    # Old testing code follows.

    """
    # testing code for readme
    for instruction in [
        "Tell me about alpacas.",
        "Tell me about the president of Mexico in 2019.",
        "Tell me about the king of France in 2019.",
        "List all Canadian provinces in alphabetical order.",
        "Write a Python program that prints the first 10 Fibonacci numbers.",
        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",  # noqa: E501
        "Tell me five words that rhyme with 'shock'.",
        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
        "Count up from 1 to 500.",
    ]:
        print("Instruction:", instruction)
        print("Response:", evaluate(instruction))
        print()
    """
if __name__ == "__main__":
    fire.Fire(main)
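# Example invocation (model names come from this script's defaults and hint text;
# substitute your own base model and adapter weights):
#   python generate.py --base_model 'huggyllama/llama-7b' --lora_weights 'tloen/alpaca-lora-7b'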