alpaca-lora/finetune.py

207 lines
5.7 KiB
Python
Raw Normal View History

2023-03-13 17:34:26 -04:00
import os
import sys
2023-03-13 17:34:26 -04:00
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers
2023-03-16 15:08:13 -04:00
# LLaMA only exists in recent HuggingFace transformers builds. Look the
# "models.llama" entry up with .get() so an older install (which has no such
# entry at all) reaches the reinstall hint below instead of failing with a bare
# KeyError. Raise explicitly so the check is not stripped by `python -O`, as an
# `assert` statement would be.
if "LlamaTokenizer" not in transformers._import_structure.get("models.llama", ()):
    raise ImportError(
        "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
    )
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
prepare_model_for_int8_training,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
)
2023-03-13 17:34:26 -04:00
# Hyperparameters, optimized for an RTX 4090. For larger GPUs, some of these
# (e.g. MICRO_BATCH_SIZE) could presumably be increased.
MICRO_BATCH_SIZE = 4  # this could actually be 5 but i like powers of 2
BATCH_SIZE = 128  # effective batch size per optimizer step (via accumulation)
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3  # we don't always need 3 tbh
LEARNING_RATE = 3e-4  # the Karpathy constant
CUTOFF_LEN = 256  # 256 accounts for about 96% of the data
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
VAL_SET_SIZE = 2000  # number of examples held out for evaluation
TARGET_MODULES = [
    "q_proj",
    "v_proj",
]
DATA_PATH = "alpaca_data_cleaned.json"
device_map = "auto"

# Distributed data parallel: WORLD_SIZE / LOCAL_RANK are read from the
# environment (presumably set by the launcher, e.g. torchrun).
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    # Place the whole model (root module "") on this process's local device.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    # Every process contributes MICRO_BATCH_SIZE per step, so split the
    # accumulation across processes to keep the effective batch near
    # BATCH_SIZE. Clamp to 1: with more than BATCH_SIZE // MICRO_BATCH_SIZE
    # processes the floor division would yield 0, which is not a valid
    # number of accumulation steps.
    GRADIENT_ACCUMULATION_STEPS = max(1, GRADIENT_ACCUMULATION_STEPS // world_size)
# Tokenizer for the base checkpoint. add_eos_token=True asks it to append the
# EOS token to encoded text.
tokenizer = LlamaTokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf", add_eos_token=True
)
# Pad with id 0 (unk). we want this to be different from the eos token.
tokenizer.pad_token_id = 0

# LoRA adapter configuration applied to the TARGET_MODULES projections.
config = LoraConfig(
    task_type="CAUSAL_LM",
    bias="none",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
)

# Base weights loaded in 8-bit and placed according to device_map, prepared
# for int8 training, then wrapped with the LoRA adapters configured above.
model = get_peft_model(
    prepare_model_for_int8_training(
        LlamaForCausalLM.from_pretrained(
            "decapoda-research/llama-7b-hf",
            load_in_8bit=True,
            device_map=device_map,
        )
    ),
    config,
)
# Load the JSON instruction data, then hold out VAL_SET_SIZE examples
# (an integer test_size is an absolute count) for evaluation. The fixed seed
# makes the split reproducible across runs.
data = load_dataset("json", data_files=DATA_PATH)

train_val = data["train"].train_test_split(
    test_size=VAL_SET_SIZE,
    shuffle=True,
    seed=42,
)
train_data, val_data = train_val["train"], train_val["test"]
def generate_prompt(data_point):
    """Render one Alpaca-style record as a complete prompt including the answer.

    ``data_point`` must provide "instruction", "input" and "output" keys.
    The "### Input:" section is emitted only when "input" is truthy, and the
    "output" text is appended directly after the "### Response:" line.
    """
    if not data_point["input"]:
        # Instruction-only variant (no extra context supplied).
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Response:
{data_point["output"]}"""

    # Instruction paired with an input that gives further context.
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Input:
{data_point["input"]}
### Response:
{data_point["output"]}"""
def tokenize(prompt):
    """Encode ``prompt`` into fixed-length ``input_ids`` / ``attention_mask``.

    The text is truncated or padded to CUTOFF_LEN + 1 tokens, and the final
    position is then dropped, so both returned lists have CUTOFF_LEN entries.
    (Not called anywhere in the visible part of this file.)
    """
    # there's probably a way to do this with the tokenizer settings
    # but again, gotta move fast
    encoded = tokenizer(
        prompt,
        max_length=CUTOFF_LEN + 1,
        truncation=True,
        padding="max_length",
    )
    return {key: encoded[key][:-1] for key in ("input_ids", "attention_mask")}
def generate_and_tokenize_prompt(data_point):
    """Build one training example whose loss covers only the response.

    The prompt part (instruction + optional input, ending in "### Response:\\n")
    is rendered with ``generate_prompt`` and an empty output, so training uses
    exactly the same template text as that helper. The prompt plus response is
    tokenized to CUTOFF_LEN + 1 tokens, and the final position is dropped.

    Returns a dict with three lists of length CUTOFF_LEN:
      input_ids:      token ids (padded with tokenizer.pad_token_id).
      attention_mask: the tokenizer's mask (1 = real token, 0 = padding).
      labels:         input_ids with prompt tokens and padding set to -100,
                      so they are ignored by the loss.
    """
    # This function masks out the labels for the input,
    # so that our loss is computed only on the response.
    user_prompt = generate_prompt({**data_point, "output": ""})

    # Number of prompt tokens to mask. Do NOT pad here: with
    # padding="max_length" every prompt would measure CUTOFF_LEN + 1 tokens,
    # the mask would cover the whole sequence, and every label would be -100.
    len_user_prompt_tokens = (
        len(
            tokenizer(
                user_prompt,
                truncation=True,
                max_length=CUTOFF_LEN + 1,
            )["input_ids"]
        )
        - 1
    )  # no eos token

    full = tokenizer(
        user_prompt + data_point["output"],
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    # Drop the last position so every sequence is exactly CUTOFF_LEN long.
    full_tokens = full["input_ids"][:-1]
    attention_mask = full["attention_mask"][:-1]

    # NOTE(review): masking the leading len_user_prompt_tokens positions
    # assumes right padding (prompt tokens come first) - confirm
    # tokenizer.padding_side == "right". Prompts that fill the whole window
    # end up with every label masked.
    n_masked = min(len_user_prompt_tokens, len(full_tokens))
    labels = [-100] * n_masked + [
        token if keep else -100
        for token, keep in zip(full_tokens[n_masked:], attention_mask[n_masked:])
    ]
    return {
        "input_ids": full_tokens,
        "labels": labels,
        "attention_mask": attention_mask,
    }


train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
val_data = val_data.shuffle().map(generate_and_tokenize_prompt)

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=20,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=200,
        save_steps=200,
        output_dir="lora-alpaca",
        save_total_limit=3,
        load_best_model_at_end=True,
        ddp_find_unused_parameters=False if ddp else None,
    ),
    # DataCollatorForLanguageModeling(mlm=False) rebuilds "labels" from
    # input_ids and would discard the prompt masking computed above.
    # DataCollatorForSeq2Seq keeps the provided labels, padding them with -100
    # only if lengths ever differ.
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, return_tensors="pt", padding=True
    ),
)
# Turn off the model's use_cache flag for training.
# NOTE(review): presumably because the KV cache only helps incremental
# generation - confirm.
model.config.use_cache = False

# Replace model.state_dict with a version that filters the original state dict
# through get_peft_model_state_dict. Presumably this yields only the LoRA
# adapter weights, so anything calling model.state_dict() (e.g. the Trainer
# when checkpointing) saves those instead of the full base model.
# old_state_dict keeps the original bound method. __get__ binds the lambda as a
# method of this specific model object, so `self` is fixed here and is not
# affected by the later rebinding of the name `model` to the compiled wrapper.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# Only attempt torch.compile on PyTorch 2.x and off Windows.
# NOTE(review): `trainer` was constructed earlier with the uncompiled model, so
# the compiled wrapper does not appear to be what trainer.train() runs.
# Compiling before the Trainer is built may have been intended - confirm.
if torch.__version__ >= "2" and sys.platform != 'win32':
    model = torch.compile(model)

trainer.train()

# Write the trained weights to ./lora-alpaca.
# NOTE(review): when the compile branch ran, `model` is the compiled wrapper,
# and this relies on it forwarding save_pretrained to the wrapped PEFT model -
# confirm.
model.save_pretrained("lora-alpaca")

print("\n If there's a warning about missing keys above, please disregard :)")