import os
import sys
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (  # noqa: E402
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402


def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
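    # Under torchrun/DDP, pin each process to its local GPU and divide the
    # gradient-accumulation steps across ranks so the effective global batch
    # size stays equal to batch_size.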
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result
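
    # Build the full prompt, tokenize it, and, when train_on_inputs is False,
    # replace the prompt-token labels with -100 so the loss is computed only
    # on the response tokens.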
    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    if data_path.endswith(".json"):  # todo: support jsonl
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)
    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            model = set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False
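
    # Patch model.state_dict so that saving exports only the LoRA adapter
    # weights (via get_peft_model_state_dict) rather than the full base-model
    # state dict.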
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)
    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""


if __name__ == "__main__":
    fire.Fire(train)
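
# Example invocation (illustrative only; assumes this script is saved as
# finetune.py and that the model/data/output paths match your setup; every
# flag maps onto a parameter of train() via fire):
#
#   python finetune.py \
#       --base_model='decapoda-research/llama-7b-hf' \
#       --data_path='yahma/alpaca-cleaned' \
#       --output_dir='./lora-alpaca'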