Mirror of https://github.com/tloen/alpaca-lora.git
mask prompt in loss
commit cfad895aa1
parent d66908c0ca
.gitignore (vendored): 3 changes
@@ -5,3 +5,6 @@ __pycache__/
 checkpoint**
 minimal-llama**
 upload.py
+lora-**
+*ckpt
+wandb
README.md: 2 changes

@@ -2,6 +2,8 @@

 **Try the pretrained model out on Colab [here](https://colab.research.google.com/drive/1eWAmesrW99p7e1nah5bipn0zikMb8XYC)!**

+_**Update 2023-03-19:** weights have been updated with cleaned data and prompts masked out in the loss. This should reduce the number of template artifacts in outputs._
+
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
 and the code can be easily extended to the `13b`, `30b`, and `65b` models.
finetune.py: 64 changes
@@ -1,6 +1,5 @@
 import os

-# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import torch
 import torch.nn as nn
 import bitsandbytes as bnb
@@ -37,10 +36,10 @@ TARGET_MODULES = [
 DATA_PATH = "alpaca_data_cleaned.json"

 device_map = "auto"
-world_size = int(os.environ.get('WORLD_SIZE', 1))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
 ddp = world_size != 1
 if ddp:
-    device_map = {'':int(os.environ.get('LOCAL_RANK') or 0)}
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
     GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size

 model = LlamaForCausalLM.from_pretrained(
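For context (not part of this commit): WORLD_SIZE and LOCAL_RANK are the environment variables a distributed launcher such as torchrun sets for each process. A rough sketch of what the branch above evaluates to on each rank; the two-GPU launch command in the comment is illustrative only:

import os

# e.g. launched as: torchrun --nproc_per_node=2 finetune.py
world_size = int(os.environ.get("WORLD_SIZE", 1))    # 2 under that launch, 1 otherwise
local_rank = int(os.environ.get("LOCAL_RANK") or 0)  # 0 or 1, one value per process

if world_size != 1:
    # {"": n} places the entire model on device n, so each rank gets its own GPU.
    device_map = {"": local_rank}
else:
    device_map = "auto"

print(f"rank {local_rank}: device_map={device_map}")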
@@ -111,8 +110,60 @@ def tokenize(prompt):
     }


-train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
-val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
+def generate_and_tokenize_prompt(data_point):
+    # This function masks out the labels for the input,
+    # so that our loss is computed only on the response.
+    user_prompt = (
+        (
+            f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Input:
+{data_point["input"]}
+
+### Response:
+"""
+        )
+        if data_point["input"]
+        else (
+            f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Response:
+"""
+        )
+    )
+    len_user_prompt_tokens = (
+        len(
+            tokenizer(
+                user_prompt,
+                truncation=True,
+                max_length=CUTOFF_LEN + 1,
+                padding="max_length",
+            )["input_ids"]
+        )
+        - 1
+    )  # no eos token
+    full_tokens = tokenizer(
+        user_prompt + data_point["output"],
+        truncation=True,
+        max_length=CUTOFF_LEN + 1,
+        padding="max_length",
+    )["input_ids"][:-1]
+    return {
+        "input_ids": full_tokens,
+        "labels": [-100] * len_user_prompt_tokens
+        + full_tokens[len_user_prompt_tokens:],
+        "attention_mask": [1] * (len(full_tokens)),
+    }
+
+
+train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
+val_data = val_data.shuffle().map(generate_and_tokenize_prompt)

 trainer = transformers.Trainer(
     model=model,
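For context (not part of this commit): the -100 values written into "labels" are what keep the prompt out of the loss. Hugging Face causal LM heads compute their loss with PyTorch cross-entropy, whose ignore_index defaults to -100, so masked positions contribute nothing to the gradient. A minimal standalone sketch with toy tensors, no model involved:

import torch
import torch.nn as nn

vocab_size = 8
logits = torch.randn(6, vocab_size)                 # 6 token positions
labels = torch.tensor([-100, -100, -100, 2, 5, 1])  # first 3 = prompt, masked out

loss_fn = nn.CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fn(logits, labels)   # averaged over the 3 unmasked positions only

# Identical to computing the loss on the response tokens alone:
assert torch.allclose(loss, loss_fn(logits[3:], labels[3:]))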
@@ -144,6 +195,9 @@ model.state_dict = (
     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
 ).__get__(model, type(model))

+if torch.__version__ >= "2":
+    model = torch.compile(model)
+
 trainer.train()

 model.save_pretrained("lora-alpaca")
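For context (not part of this commit): torch.compile was introduced in PyTorch 2.0, and the guard above relies on torch.__version__ comparing as a string ("2.x.y" sorts after "2", "1.x.y" before it). A small isolated sketch of the same pattern, using a stand-in module:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the real model, illustration only

if torch.__version__ >= "2":  # same string comparison as in the commit
    model = torch.compile(model)

print(type(model).__name__)  # "OptimizedModule" on PyTorch 2.x, "Linear" otherwise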