Update alpaca-lora to use transformers main branch

andreas.echavez 2023-03-16 08:34:33 -06:00 committed by Eric Wang
parent c3d94707ec
commit 1862976b33
5 changed files with 13 additions and 16 deletions

View File

@@ -16,16 +16,13 @@ Without hyperparameter tuning or validation-based checkpointing, the LoRA model
 ### Setup
-Until Jason Phang's [LLaMA implementation](https://github.com/huggingface/transformers/pull/21955)
-is merged, users will need to replace their local `transformers` package.
-1. Install dependencies (**install zphang's transformers fork**)
+1. Install dependencies
 ```
 pip install -q datasets loralib sentencepiece accelerate
 pip uninstall transformers
-pip install -q git+https://github.com/zphang/transformers@c3dc391
+pip install -q git+https://github.com/huggingface/transformers.git
 pip install -q git+https://github.com/huggingface/peft.git
 ```
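After switching from zphang's fork to the `transformers` main branch, a quick sanity check is to confirm the renamed `Llama*` classes import cleanly. A minimal sketch (not part of the commit); the printed version string is illustrative:

```
# On a pre-rename build (e.g. zphang's fork), this import raises ImportError,
# since the classes there were spelled LLaMATokenizer / LLaMAForCausalLM.
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM

print(transformers.__version__)  # a main-branch dev version, e.g. "4.28.0.dev0" (illustrative)
```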

View File

@@ -3,11 +3,11 @@ import json
 import torch
 from peft import PeftModel, LoraConfig
-from transformers import LLaMATokenizer, LLaMAForCausalLM
-tokenizer = LLaMATokenizer.from_pretrained("decapoda-research/llama-7b-hf")
-base_model = LLaMAForCausalLM.from_pretrained(
+from transformers import LlamaTokenizer, LlamaForCausalLM
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+base_model = LlamaForCausalLM.from_pretrained(
     "decapoda-research/llama-7b-hf",
     load_in_8bit=False,
     torch_dtype=torch.float16,
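The hunk above is truncated mid-call; for context, a minimal sketch of the loading step this export script performs with the renamed classes. The adapter id `"tloen/alpaca-lora-7b"` is an assumption for illustration — substitute your own LoRA checkpoint:

```
import torch
from peft import PeftModel
from transformers import LlamaTokenizer, LlamaForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
base_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=False,          # full-precision weights, so a merged state dict can be exported
    torch_dtype=torch.float16,
)
# Attach the trained LoRA weights on top of the base model
# ("tloen/alpaca-lora-7b" is an assumed example adapter id).
model = PeftModel.from_pretrained(base_model, "tloen/alpaca-lora-7b")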

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
 import bitsandbytes as bnb
 from datasets import load_dataset
 import transformers
-from transformers import AutoTokenizer, AutoConfig, LLaMAForCausalLM, LLaMATokenizer
+from transformers import AutoTokenizer, AutoConfig, LlamaForCausalLM, LlamaTokenizer
 from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model
@@ -21,12 +21,12 @@ LORA_R = 8
 LORA_ALPHA = 16
 LORA_DROPOUT = 0.05
-model = LLaMAForCausalLM.from_pretrained(
+model = LlamaForCausalLM.from_pretrained(
     "decapoda-research/llama-7b-hf",
     load_in_8bit=True,
     device_map="auto",
 )
-tokenizer = LLaMATokenizer.from_pretrained(
+tokenizer = LlamaTokenizer.from_pretrained(
     "decapoda-research/llama-7b-hf", add_eos_token=True
 )
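For context, a hedged sketch of how the `LORA_*` constants and `peft` imports shown in this hunk typically wire together; `target_modules` is an assumption (alpaca-lora adapts the attention projections) and is not shown in this diff:

```
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# Prepare the 8-bit model for training (casts norm layers/head, enables input grads).
model = prepare_model_for_int8_training(model)

config = LoraConfig(
    r=LORA_R,                              # 8
    lora_alpha=LORA_ALPHA,                 # 16
    lora_dropout=LORA_DROPOUT,             # 0.05
    target_modules=["q_proj", "v_proj"],   # assumed; not shown in this hunk
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # only a small fraction of the 7B weights train
```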

View File

@@ -1,10 +1,10 @@
 import torch
 from peft import PeftModel
-from transformers import LLaMATokenizer, LLaMAForCausalLM, GenerationConfig
-tokenizer = LLaMATokenizer.from_pretrained("decapoda-research/llama-7b-hf")
-model = LLaMAForCausalLM.from_pretrained(
+from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+model = LlamaForCausalLM.from_pretrained(
     "decapoda-research/llama-7b-hf",
     load_in_8bit=True,
     torch_dtype=torch.float16,
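A hedged sketch of the generation call these imports set up; the prompt and sampling parameters below are illustrative assumptions, not values from this commit:

```
from transformers import GenerationConfig

prompt = "Tell me about alpacas."  # illustrative
inputs = tokenizer(prompt, return_tensors="pt")
generation_config = GenerationConfig(temperature=0.1, top_p=0.75, num_beams=4)

# Generate with the LoRA-augmented model loaded above.
output = model.generate(
    input_ids=inputs["input_ids"].to(model.device),
    generation_config=generation_config,
    max_new_tokens=128,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```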

View File

@@ -19,10 +19,10 @@
    ],
    "source": [
     "from datasets import load_dataset\n",
-    "from transformers import LLaMATokenizer\n",
+    "from transformers import LlamaTokenizer\n",
     "\n",
     "\n",
-    "tokenizer = LLaMATokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\", add_eos_token=True)\n",
+    "tokenizer = LlamaTokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\", add_eos_token=True)\n",
     "tokenizer.pad_token = tokenizer.eos_token\n",
     "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
     "\n",