mirror of https://github.com/tloen/alpaca-lora.git

mask prompt in loss

This commit is contained in:
parent d66908c0ca
commit cfad895aa1
.gitignore (vendored) | 3
@@ -5,3 +5,6 @@ __pycache__/
 checkpoint**
 minimal-llama**
 upload.py
+lora-**
+*ckpt
+wandb
README.md | 2

@@ -2,6 +2,8 @@
 
 **Try the pretrained model out on Colab [here](https://colab.research.google.com/drive/1eWAmesrW99p7e1nah5bipn0zikMb8XYC)!**
 
+_**Update 2023-03-19:** weights have been updated with cleaned data and prompts masked out in the loss. This should reduce the number of template artifacts in outputs._
+
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
 and the code can be easily extended to the `13b`, `30b`, and `65b` models.
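Context for the "prompts masked out in the loss" note above: PyTorch's `CrossEntropyLoss` (which `transformers` causal-LM heads use internally, along with label shifting) skips any target equal to its `ignore_index`, and `-100` is both PyTorch's default and the `transformers` labels convention. A minimal standalone sketch, not part of this commit, with 32000 as the LLaMA vocabulary size:

```python
import torch
import torch.nn as nn

logits = torch.randn(1, 5, 32000)                # (batch, seq_len, vocab)
labels = torch.tensor([[-100, -100, 7, 42, 9]])  # first two positions are "prompt"

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.view(-1, 32000), labels.view(-1))
# Only the last three positions contribute; the prompt tokens produce no gradient.
```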
finetune.py | 64
@@ -1,6 +1,5 @@
 import os
 
-# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import torch
 import torch.nn as nn
 import bitsandbytes as bnb
@@ -37,10 +36,10 @@ TARGET_MODULES = [
 DATA_PATH = "alpaca_data_cleaned.json"
 
 device_map = "auto"
-world_size = int(os.environ.get('WORLD_SIZE', 1))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
 ddp = world_size != 1
 if ddp:
-    device_map = {'':int(os.environ.get('LOCAL_RANK') or 0)}
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
     GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
 
 model = LlamaForCausalLM.from_pretrained(
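On the `device_map` idiom in the hunk above: `{"": rank}` maps the root module (the empty-string key, i.e. the whole model) to a single GPU, while `"auto"` lets accelerate spread layers across visible devices. A minimal sketch of the behavior, assuming a launch like `torchrun --nproc_per_node=2 finetune.py` (the launcher command is an assumption, but torchrun does set these environment variables per worker):

```python
import os

# torchrun exports WORLD_SIZE and LOCAL_RANK for every worker process.
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size != 1:
    # DDP: pin the entire model ("" = root module) to this worker's GPU.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
else:
    # Single process: let accelerate place layers automatically.
    device_map = "auto"
```

Dividing `GRADIENT_ACCUMULATION_STEPS` by `world_size` keeps the effective global batch size constant as workers are added, since each worker already contributes its own micro-batches.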
@@ -111,8 +110,60 @@ def tokenize(prompt):
     }
 
 
-train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
-val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
+def generate_and_tokenize_prompt(data_point):
+    # This function masks out the labels for the input,
+    # so that our loss is computed only on the response.
+    user_prompt = (
+        (
+            f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Input:
+{data_point["input"]}
+
+### Response:
+"""
+        )
+        if data_point["input"]
+        else (
+            f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Response:
+"""
+        )
+    )
+    len_user_prompt_tokens = (
+        len(
+            tokenizer(
+                user_prompt,
+                truncation=True,
+                max_length=CUTOFF_LEN + 1,
+                padding="max_length",
+            )["input_ids"]
+        )
+        - 1
+    )  # no eos token
+    full_tokens = tokenizer(
+        user_prompt + data_point["output"],
+        truncation=True,
+        max_length=CUTOFF_LEN + 1,
+        padding="max_length",
+    )["input_ids"][:-1]
+    return {
+        "input_ids": full_tokens,
+        "labels": [-100] * len_user_prompt_tokens
+        + full_tokens[len_user_prompt_tokens:],
+        "attention_mask": [1] * (len(full_tokens)),
+    }
+
+
+train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
+val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
 
 
 trainer = transformers.Trainer(
     model=model,
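The key move in `generate_and_tokenize_prompt` is that `labels` mirrors `input_ids` except that the prompt positions are overwritten with `-100`, so gradient flows only from response tokens. The `max_length=CUTOFF_LEN + 1` plus `[:-1]` slice (and the `- 1` on the prompt length) trims the trailing eos token the tokenizer would otherwise append. A toy illustration of the masking arithmetic, with made-up token ids rather than real LLaMA tokenizer output:

```python
# Hypothetical ids: suppose the prompt tokenizes to 4 tokens and
# prompt + response to 7.
full_tokens = [1, 512, 88, 340, 9021, 77, 2]   # prompt (4) + response (3)
len_user_prompt_tokens = 4

labels = [-100] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:]
assert labels == [-100, -100, -100, -100, 9021, 77, 2]
# Loss is computed only where labels != -100, i.e. on the response.
```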
@@ -144,6 +195,9 @@ model.state_dict = (
     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
 ).__get__(model, type(model))
 
+if torch.__version__ >= "2":
+    model = torch.compile(model)
+
 trainer.train()
 
 model.save_pretrained("lora-alpaca")
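One detail in the new guard above: `torch.__version__` is a string-like value, so `>= "2"` is a lexicographic comparison. It distinguishes "1.13.1" from "2.0.0" correctly, though parsing the major version is stricter. A small self-contained sketch (the `nn.Linear` stand-in is only for illustration):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in module for illustration

# Compare the parsed major version instead of relying on string ordering.
major = int(torch.__version__.split(".")[0])
if major >= 2:
    model = torch.compile(model)  # graph compilation, new in PyTorch 2.0
```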