2023-03-13 17:34:26 -04:00
import os
2023-03-19 23:16:02 -04:00
import sys
2023-03-13 17:34:26 -04:00
import torch
import torch . nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers
2023-03-16 15:08:13 -04:00
# Fail fast, with an actionable message, when the installed transformers build
# does not register LlamaTokenizer. `.get` keeps a build with no
# "models.llama" entry at all from surfacing as a bare KeyError, and an explicit
# raise (unlike `assert`) is not stripped when Python runs with -O.
if "LlamaTokenizer" not in transformers._import_structure.get("models.llama", ()):
    raise ImportError(
        "LLaMA is now in HuggingFace's main branch.\n"
        "Please reinstall it: pip uninstall transformers && "
        "pip install git+https://github.com/huggingface/transformers.git"
    )
from transformers import LlamaForCausalLM , LlamaTokenizer
from peft import (
prepare_model_for_int8_training ,
LoraConfig ,
get_peft_model ,
get_peft_model_state_dict ,
)
2023-03-13 17:34:26 -04:00
2023-03-14 19:30:38 -04:00
# optimized for RTX 4090. for larger GPUs, increase some of these?
MICRO_BATCH_SIZE = 4  # this could actually be 5 but i like powers of 2
BATCH_SIZE = 128
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3  # we don't always need 3 tbh
LEARNING_RATE = 3e-4  # the Karpathy constant
CUTOFF_LEN = 256  # 256 accounts for about 96% of the data
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
VAL_SET_SIZE = 2000
TARGET_MODULES = [
    "q_proj",
    "v_proj",
]
DATA_PATH = "alpaca_data_cleaned.json"
OUTPUT_DIR = "lora-alpaca"

# Single-process runs use device_map="auto"; when WORLD_SIZE reports more than
# one process, each process pins the whole model ("" key) to its LOCAL_RANK.
device_map = "auto"
# `or 1` treats an empty WORLD_SIZE like an unset one, mirroring LOCAL_RANK below.
world_size = int(os.environ.get("WORLD_SIZE") or 1)
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    # Split the accumulation across processes, but never let the integer
    # division floor to 0 when there are more processes than accumulation steps.
    GRADIENT_ACCUMULATION_STEPS = max(1, GRADIENT_ACCUMULATION_STEPS // world_size)
# Tokenizer: append eos to every encoding and pad with id 0.
tokenizer = LlamaTokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf", add_eos_token=True
)
tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token

# Base model loaded in 8-bit on the device mapping chosen above, then wrapped
# with LoRA adapters on the configured target modules.
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    device_map=device_map,
)
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

# Load the JSON dataset and hold out VAL_SET_SIZE examples (fixed seed) for
# evaluation.
data = load_dataset("json", data_files=DATA_PATH)
train_val = data["train"].train_test_split(
    test_size=VAL_SET_SIZE, shuffle=True, seed=42
)
train_data, val_data = train_val["train"], train_val["test"]
def generate_prompt(data_point):
    """Render one dataset record as a full training prompt.

    Args:
        data_point: Mapping with "instruction" and "output" entries and an
            optional "input" entry. A missing or empty "input" selects the
            shorter template that has no "### Input:" section.

    Returns:
        The prompt text, ending with the record's "output".

    Raises:
        KeyError: If "instruction" or "output" is missing.
    """
    # NOTE(review): the published Alpaca template separates sections with
    # blank lines; confirm this template matches the one used at inference.
    instruction = data_point["instruction"]
    response = data_point["output"]
    # .get(): records without an "input" field are treated like empty input
    # instead of raising KeyError.
    context = data_point.get("input")
    if context:
        return (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n"
            f"### Instruction:\n{instruction}\n"
            f"### Input:\n{context}\n"
            f"### Response:\n{response}"
        )
    return (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n"
        f"### Instruction:\n{instruction}\n"
        f"### Response:\n{response}"
    )
def tokenize(prompt):
    """Encode `prompt` with the module-level tokenizer at a fixed length.

    The text is truncated/padded to CUTOFF_LEN + 1 positions and the final
    position is then dropped from both returned lists.

    Returns:
        Dict with "input_ids" and "attention_mask".
    """
    # there's probably a way to do this with the tokenizer settings
    # but again, gotta move fast
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    return {
        field: encoded[field][:-1] for field in ("input_ids", "attention_mask")
    }
def generate_and_tokenize_prompt(data_point):
    """Tokenize one record and mask its labels so loss covers only the response.

    Args:
        data_point: Mapping with "instruction", "input" and "output" entries.

    Returns:
        Dict with fixed-length lists:
        "input_ids"      - prompt + response, truncated/padded to CUTOFF_LEN + 1
                           positions with the final position dropped;
        "labels"         - copy of "input_ids" with -100 on every prompt
                           position and every padding position;
        "attention_mask" - the tokenizer's mask for the same positions.
    """
    # Same template as the full prompt, with the response left empty.
    user_prompt = generate_prompt({**data_point, "output": ""})

    # Measure the prompt WITHOUT padding. With padding="max_length" every
    # encoding has the same length, so the "prompt length" would always equal
    # CUTOFF_LEN and the whole sequence would be masked out.
    prompt_ids = tokenizer(
        user_prompt,
        truncation=True,
        max_length=CUTOFF_LEN + 1,
    )["input_ids"]
    if prompt_ids and prompt_ids[-1] == tokenizer.eos_token_id:
        prompt_ids = prompt_ids[:-1]  # trailing eos is not part of the prompt

    encoded = tokenizer(
        user_prompt + data_point["output"],
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    full_tokens = encoded["input_ids"][:-1]
    attention_mask = encoded["attention_mask"][:-1]

    # A prompt that fills the whole window leaves no response tokens to learn.
    len_user_prompt_tokens = min(len(prompt_ids), len(full_tokens))

    # NOTE(review): the prefix masking assumes padding is appended on the
    # right, so the prompt starts at position 0 - confirm the tokenizer's
    # padding side.
    labels = [
        token if position >= len_user_prompt_tokens and attended else -100
        for position, (token, attended) in enumerate(
            zip(full_tokens, attention_mask)
        )
    ]
    return {
        "input_ids": full_tokens,
        "labels": labels,
        "attention_mask": attention_mask,
    }
# Shuffle each split and convert every record to token ids / labels / mask.
train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
val_data = val_data.shuffle().map(generate_and_tokenize_prompt)

# Evaluate and checkpoint every 200 steps, keep at most 3 checkpoints, and
# reload the best checkpoint when training ends.
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_steps=100,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    fp16=True,
    logging_steps=20,
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=200,
    save_steps=200,
    output_dir=OUTPUT_DIR,
    save_total_limit=3,
    load_best_model_at_end=True,
    ddp_find_unused_parameters=False if ddp else None,
)

# NOTE(review): presumably DataCollatorForLanguageModeling(mlm=False) rebuilds
# `labels` from `input_ids`; confirm the per-example `labels` column produced
# by the map above survives collation.
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_args,
    data_collator=transformers.DataCollatorForLanguageModeling(
        tokenizer, mlm=False
    ),
)
model.config.use_cache = False

# Rebind model.state_dict so that it returns the result of
# get_peft_model_state_dict applied to the original state dict.
old_state_dict = model.state_dict


def _peft_state_dict(self, *_, **__):
    return get_peft_model_state_dict(self, old_state_dict())


model.state_dict = _peft_state_dict.__get__(model, type(model))

# NOTE(review): `trainer` was constructed with the pre-compile `model` object;
# confirm the compiled module is meant to affect training and not only the
# final save below.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

trainer.train()

model.save_pretrained(OUTPUT_DIR)

print("\n If there's a warning about missing keys above, please disregard :)")