From cfad895aa113affb0e4bddc6d962ad08bfba582c Mon Sep 17 00:00:00 2001
From: Eric Wang
Date: Sun, 19 Mar 2023 15:53:00 -0700
Subject: [PATCH] mask prompt in loss

---
 .gitignore  |  5 ++++-
 README.md   |  2 ++
 finetune.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0a4fb6c..40f8eee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,7 @@ out/
 __pycache__/
 checkpoint**
 minimal-llama**
-upload.py
\ No newline at end of file
+upload.py
+lora-**
+*ckpt
+wandb
\ No newline at end of file
diff --git a/README.md b/README.md
index b581006..f6c81e2 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,8 @@
 
 **Try the pretrained model out on Colab [here](https://colab.research.google.com/drive/1eWAmesrW99p7e1nah5bipn0zikMb8XYC)!**
 
+_**Update 2023-03-19:** weights have been updated with cleaned data and prompts masked out in the loss. This should reduce the number of template artifacts in outputs._
+
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research), and the code can be easily extended to the `13b`, `30b`, and `65b` models.
 
diff --git a/finetune.py b/finetune.py
index dbaf9fc..23afb78 100644
--- a/finetune.py
+++ b/finetune.py
@@ -1,6 +1,5 @@
 import os
-# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import torch
 import torch.nn as nn
 import bitsandbytes as bnb
 from datasets import load_dataset
@@ -37,10 +36,10 @@ TARGET_MODULES = [
 DATA_PATH = "alpaca_data_cleaned.json"
 
 device_map = "auto"
-world_size = int(os.environ.get('WORLD_SIZE', 1))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
 ddp = world_size != 1
 if ddp:
-    device_map = {'':int(os.environ.get('LOCAL_RANK') or 0)}
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
     GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
 
 model = LlamaForCausalLM.from_pretrained(
@@ -111,8 +110,60 @@ def tokenize(prompt):
     }
 
 
-train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
-val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
+def generate_and_tokenize_prompt(data_point):
+    # This function masks out the labels for the input,
+    # so that our loss is computed only on the response.
+    user_prompt = (
+        (
+            f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Input:
+{data_point["input"]}
+
+### Response:
+"""
+        )
+        if data_point["input"]
+        else (
+            f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Response:
+"""
+        )
+    )
+    len_user_prompt_tokens = (
+        len(
+            tokenizer(
+                user_prompt,
+                truncation=True,
+                max_length=CUTOFF_LEN + 1,
+                # no padding here, so this measures the true prompt length
+            )["input_ids"]
+        )
+        - 1
+    )  # no eos token
+    full_tokens = tokenizer(
+        user_prompt + data_point["output"],
+        truncation=True,
+        max_length=CUTOFF_LEN + 1,
+        padding="max_length",
+    )["input_ids"][:-1]
+    return {
+        "input_ids": full_tokens,
+        "labels": [-100] * len_user_prompt_tokens
+        + full_tokens[len_user_prompt_tokens:],
+        "attention_mask": [1] * (len(full_tokens)),
+    }
+
+
+train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
+val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
 
 trainer = transformers.Trainer(
     model=model,
@@ -144,6 +195,9 @@ model.state_dict = (
     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
 ).__get__(model, type(model))
 
+if torch.__version__ >= "2":
+    model = torch.compile(model)
+
 trainer.train()
 
 model.save_pretrained("lora-alpaca")
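
For reference, the masking scheme introduced by generate_and_tokenize_prompt can be
sanity-checked in isolation. The sketch below is not part of the patch: toy_tokenize is a
hypothetical stand-in for LlamaTokenizer (so it runs without model weights), the prompt and
response strings are invented, and only the masking arithmetic mirrors the code above. It
demonstrates that every label position covered by the prompt is set to -100, the value the
Hugging Face causal-LM loss ignores, so gradients flow only from response tokens.

CUTOFF_LEN = 32  # finetune.py uses 256


def toy_tokenize(text):
    # Stand-in for tokenizer(...)["input_ids"] without padding: one
    # "token" per whitespace-separated word, plus a BOS-like id (1).
    return [1] + [hash(word) % 30000 for word in text.split()]


def mask_prompt(user_prompt, output):
    # Mirrors generate_and_tokenize_prompt: count the prompt's tokens,
    # then overwrite those label positions with -100 so the loss skips them.
    len_user_prompt_tokens = len(toy_tokenize(user_prompt)[: CUTOFF_LEN + 1]) - 1
    full_tokens = toy_tokenize(user_prompt + output)[: CUTOFF_LEN + 1][:-1]
    return {
        "input_ids": full_tokens,
        "labels": [-100] * len_user_prompt_tokens
        + full_tokens[len_user_prompt_tokens:],
        "attention_mask": [1] * len(full_tokens),
    }


row = mask_prompt(
    "### Instruction: Name the capital of France. ### Response: ",
    "The capital of France is Paris.",
)
n_masked = sum(label == -100 for label in row["labels"])
assert row["labels"][:n_masked] == [-100] * n_masked            # prompt ignored
assert row["labels"][n_masked:] == row["input_ids"][n_masked:]  # response kept
print(f"{n_masked} of {len(row['labels'])} label positions are masked")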