mask prompt in loss

Eric Wang 2023-03-19 15:53:00 -07:00
parent d66908c0ca
commit cfad895aa1
3 changed files with 65 additions and 6 deletions

.gitignore

@@ -5,3 +5,6 @@ __pycache__/
 checkpoint**
 minimal-llama**
 upload.py
+lora-**
+*ckpt
+wandb

README.md

@@ -2,6 +2,8 @@
 **Try the pretrained model out on Colab [here](https://colab.research.google.com/drive/1eWAmesrW99p7e1nah5bipn0zikMb8XYC)!**
 
+_**Update 2023-03-19:** weights have been updated with cleaned data and prompts masked out in the loss. This should reduce the number of template artifacts in outputs._
+
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
 and the code can be easily extended to the `13b`, `30b`, and `65b` models.
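
For context on the update note above: a minimal, runnable sketch (not part of this commit) of why setting label positions to -100 masks them out of the loss. PyTorch's cross-entropy, which HF Transformers uses under the hood for the causal LM loss, skips positions whose label equals ignore_index (default -100), so only response tokens contribute to training.

# Sketch only: illustrates -100 label masking, not code from this repo.
import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(1, 5, vocab_size)          # (batch, seq_len, vocab)
labels = torch.tensor([[-100, -100, 3, 1, 2]])  # first two positions are "prompt"

loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    labels.view(-1),
    ignore_index=-100,  # the default in both PyTorch and HF Transformers
)
print(loss)  # averaged over the three unmasked (response) positions only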

finetune.py

@@ -1,6 +1,5 @@
 import os
-# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import torch
 import torch.nn as nn
 import bitsandbytes as bnb
@@ -37,10 +36,10 @@ TARGET_MODULES = [
 DATA_PATH = "alpaca_data_cleaned.json"
 device_map = "auto"
-world_size = int(os.environ.get('WORLD_SIZE', 1))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
 ddp = world_size != 1
 if ddp:
-    device_map = {'':int(os.environ.get('LOCAL_RANK') or 0)}
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
     GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
 
 model = LlamaForCausalLM.from_pretrained(
@@ -111,8 +110,60 @@ def tokenize(prompt):
     }
 
 
-train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
-val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
+def generate_and_tokenize_prompt(data_point):
+    # This function masks out the labels for the input,
+    # so that our loss is computed only on the response.
+    user_prompt = (
+        (
+            f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Input:
+{data_point["input"]}
+
+### Response:
+"""
+        )
+        if data_point["input"]
+        else (
+            f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Response:
+"""
+        )
+    )
+    len_user_prompt_tokens = (
+        len(
+            tokenizer(
+                user_prompt,
+                truncation=True,
+                max_length=CUTOFF_LEN + 1,
+                padding="max_length",
+            )["input_ids"]
+        )
+        - 1
+    )  # no eos token
+    full_tokens = tokenizer(
+        user_prompt + data_point["output"],
+        truncation=True,
+        max_length=CUTOFF_LEN + 1,
+        padding="max_length",
+    )["input_ids"][:-1]
+    return {
+        "input_ids": full_tokens,
+        "labels": [-100] * len_user_prompt_tokens
+        + full_tokens[len_user_prompt_tokens:],
+        "attention_mask": [1] * (len(full_tokens)),
+    }
+
+
+train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
+val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
 
 trainer = transformers.Trainer(
     model=model,
@@ -144,6 +195,9 @@ model.state_dict = (
     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
 ).__get__(model, type(model))
 
+if torch.__version__ >= "2":
+    model = torch.compile(model)
+
 trainer.train()
 model.save_pretrained("lora-alpaca")
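
As a sanity check on the labels construction in generate_and_tokenize_prompt, here is a tokenizer-free sketch (the integer ids are made up, not real tokenizer output) showing that positions before len_user_prompt_tokens are masked with -100 and the response ids pass through unchanged:

# Made-up ids standing in for tokenizer output; not part of this commit.
prompt_ids = [101, 102, 103]           # tokenizer(user_prompt)["input_ids"]
full_tokens = prompt_ids + [201, 202]  # prompt + response token ids
len_user_prompt_tokens = len(prompt_ids)

labels = [-100] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:]
assert labels == [-100, -100, -100, 201, 202]  # only response ids are supervised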