mirror of
https://github.com/tloen/alpaca-lora.git
synced 2024-10-01 01:05:56 -04:00
Add HF dataset loading, add linters, pyproject.toml (#175)
* add HF dataset loading, add linters, pyproject.toml - applied markdownlint - add black, black[jupyter], isort - fix noqa codes - add .github workflow linting - update README.md * restore default settings * resume_from_checkpoint Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com> * Print warning on checkpoint not found * add HF dataset loading, add linters, pyproject.toml - applied markdownlint - add black, black[jupyter], isort - fix noqa codes - add .github workflow linting - update README.md * Default to local copy and update it * Typo * Remove duplicate code block --------- Co-authored-by: Eric Wang <eric.james.wang@gmail.com> Co-authored-by: AngainorDev <54739135+AngainorDev@users.noreply.github.com>
This commit is contained in:
parent
b00629d773
commit
1310547f9f
33
.github/workflows/lint.yml
vendored
Normal file
33
.github/workflows/lint.yml
vendored
Normal file
@ -0,0 +1,33 @@
|
||||
name: Lint
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
run-linters:
|
||||
name: Run linters
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out Git repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: pip install black black[jupyter] flake8
|
||||
|
||||
- name: lint isort
|
||||
run: isort --check --diff
|
||||
|
||||
- name: lint black
|
||||
run: black --check --diff
|
76
README.md
76
README.md
@ -1,4 +1,4 @@
|
||||
## 🦙🌲🤏 Alpaca-LoRA: Low-Rank LLaMA Instruct-Tuning
|
||||
# 🦙🌲🤏 Alpaca-LoRA: Low-Rank LLaMA Instruct-Tuning
|
||||
|
||||
- 🤗 **Try the pretrained model out [here](https://huggingface.co/spaces/tloen/alpaca-lora), courtesy of a GPU grant from Huggingface!**
|
||||
- Users have created a Discord server for discussion and support [here](https://discord.gg/prbq284xX5)
|
||||
@ -15,15 +15,27 @@ as well as Tim Dettmers' [bitsandbytes](https://github.com/TimDettmers/bitsandby
|
||||
|
||||
Without hyperparameter tuning, the LoRA model produces outputs comparable to the Stanford Alpaca model. (Please see the outputs included below.) Further tuning might be able to achieve better performance; I invite interested users to give it a try and report their results.
|
||||
|
||||
### Setup
|
||||
## Setup
|
||||
|
||||
1. Install dependencies
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. If bitsandbytes doesn't work, [install it from source.](https://github.com/TimDettmers/bitsandbytes/blob/main/compile_from_source.md) Windows users can follow [these instructions](https://github.com/tloen/alpaca-lora/issues/17).
|
||||
1. Set environment variables, or modify the files referencing `BASE_MODEL`:
|
||||
|
||||
```bash
|
||||
# Files referencing `BASE_MODEL`
|
||||
# export_hf_checkpoint.py
|
||||
# export_state_dict_checkpoint.py
|
||||
|
||||
export BASE_MODEL=decapoda-research/llama-7b-hf
|
||||
```
|
||||
|
||||
Both `finetune.py` and `generate.py` use `--base_model` flag as shown further below.
|
||||
|
||||
1. If bitsandbytes doesn't work, [install it from source.](https://github.com/TimDettmers/bitsandbytes/blob/main/compile_from_source.md) Windows users can follow [these instructions](https://github.com/tloen/alpaca-lora/issues/17).
|
||||
|
||||
### Training (`finetune.py`)
|
||||
|
||||
@ -36,15 +48,16 @@ Example usage:
|
||||
```bash
|
||||
python finetune.py \
|
||||
--base_model 'decapoda-research/llama-7b-hf' \
|
||||
--data_path './alpaca_data_cleaned.json' \
|
||||
--data_path 'yahma/alpaca-cleaned' \
|
||||
--output_dir './lora-alpaca'
|
||||
```
|
||||
|
||||
We can also tweak our hyperparameters:
|
||||
|
||||
```bash
|
||||
python finetune.py \
|
||||
--base_model 'decapoda-research/llama-7b-hf' \
|
||||
--data_path './alpaca_data_cleaned.json' \
|
||||
--data_path 'yahma/alpaca-cleaned' \
|
||||
--output_dir './lora-alpaca' \
|
||||
--batch_size 128 \
|
||||
--micro_batch_size 4 \
|
||||
@ -81,17 +94,6 @@ They should help users
|
||||
who want to run inference in projects like [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
or [alpaca.cpp](https://github.com/antimatter15/alpaca.cpp).
|
||||
|
||||
### Dataset
|
||||
|
||||
In addition to `alpaca_data.json`, which contains the original Stanford Alpaca dataset,
|
||||
we also include `alpaca_data_cleaned.json`, which has been [stripped of various tokenization artifacts](https://github.com/tloen/alpaca-lora/pull/32)
|
||||
with the help of @gururise.
|
||||
This file is now used by default in the training script.
|
||||
|
||||
@AndriyMulyar has also provided interactive, embedding-based visualizations of the original dataset's [instructions](https://atlas.nomic.ai/map/alpaca_instructions)
|
||||
and [outputs](https://atlas.nomic.ai/map/alpaca_outputs),
|
||||
as well as [clusters of bad examples](https://atlas.nomic.ai/map/d2139cc3-bc1c-441c-8d6f-3e6ffbbc2eda/838019ff-8fe2-42ba-809a-d86d2b98cd50/-18.11668742841587/-11.348087116836096/-20.88850316347706/-17.680468640801223/774455612).
|
||||
|
||||
### Notes
|
||||
|
||||
- We can likely improve our model performance significantly if we had a better dataset. Consider supporting the [LAION Open Assistant](https://open-assistant.io/) effort to produce a high-quality dataset for supervised fine-tuning (or bugging them to release their data).
|
||||
@ -105,26 +107,26 @@ as well as [clusters of bad examples](https://atlas.nomic.ai/map/d2139cc3-bc1c-4
|
||||
- [AlpacaDataCleaned](https://github.com/gururise/AlpacaDataCleaned), a project to improve the quality of the Alpaca dataset
|
||||
- Various adapter weights (download at own risk):
|
||||
- 7B:
|
||||
- https://huggingface.co/tloen/alpaca-lora-7b
|
||||
- https://huggingface.co/samwit/alpaca7B-lora
|
||||
- 🇧🇷 https://huggingface.co/22h/cabrita-lora-v0-1
|
||||
- 🇨🇳 https://huggingface.co/qychen/luotuo-lora-7b-0.1
|
||||
- 🇯🇵 https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-7b-v0
|
||||
- 🇫🇷 https://huggingface.co/bofenghuang/vigogne-lora-7b
|
||||
- 🇹🇭 https://huggingface.co/Thaweewat/thai-buffala-lora-7b-v0-1
|
||||
- 🇩🇪 https://huggingface.co/thisserand/alpaca_lora_german
|
||||
- 🇮🇹 https://huggingface.co/teelinsan/camoscio-7b-llama
|
||||
- <https://huggingface.co/tloen/alpaca-lora-7b>
|
||||
- <https://huggingface.co/samwit/alpaca7B-lora>
|
||||
- 🇧🇷 <https://huggingface.co/22h/cabrita-lora-v0-1>
|
||||
- 🇨🇳 <https://huggingface.co/qychen/luotuo-lora-7b-0.1>
|
||||
- 🇯🇵 <https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-7b-v0>
|
||||
- 🇫🇷 <https://huggingface.co/bofenghuang/vigogne-lora-7b>
|
||||
- 🇹🇭 <https://huggingface.co/Thaweewat/thai-buffala-lora-7b-v0-1>
|
||||
- 🇩🇪 <https://huggingface.co/thisserand/alpaca_lora_german>
|
||||
- 🇮🇹 <https://huggingface.co/teelinsan/camoscio-7b-llama>
|
||||
- 13B:
|
||||
- https://huggingface.co/chansung/alpaca-lora-13b
|
||||
- https://huggingface.co/mattreid/alpaca-lora-13b
|
||||
- https://huggingface.co/samwit/alpaca13B-lora
|
||||
- 🇯🇵 https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-13b-v0
|
||||
- 🇰🇷 https://huggingface.co/chansung/koalpaca-lora-13b
|
||||
- 🇨🇳 https://huggingface.co/facat/alpaca-lora-cn-13b
|
||||
- <https://huggingface.co/chansung/alpaca-lora-13b>
|
||||
- <https://huggingface.co/mattreid/alpaca-lora-13b>
|
||||
- <https://huggingface.co/samwit/alpaca13B-lora>
|
||||
- 🇯🇵 <https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-13b-v0>
|
||||
- 🇰🇷 <https://huggingface.co/chansung/koalpaca-lora-13b>
|
||||
- 🇨🇳 <https://huggingface.co/facat/alpaca-lora-cn-13b>
|
||||
- 30B:
|
||||
- https://huggingface.co/baseten/alpaca-30b
|
||||
- https://huggingface.co/chansung/alpaca-lora-30b
|
||||
- 🇯🇵 https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-30b-v0
|
||||
- <https://huggingface.co/baseten/alpaca-30b>
|
||||
- <https://huggingface.co/chansung/alpaca-lora-30b>
|
||||
- 🇯🇵 <https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-30b-v0>
|
||||
- [alpaca-native](https://huggingface.co/chavinlo/alpaca-native), a replication using the original Alpaca code
|
||||
|
||||
### Example outputs
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,20 +1,23 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from peft import PeftModel, LoraConfig
|
||||
|
||||
import transformers
|
||||
from peft import PeftModel
|
||||
|
||||
# Unused imports
|
||||
# import json
|
||||
# from peft import LoraConfig
|
||||
|
||||
assert (
|
||||
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
|
||||
|
||||
BASE_MODEL = None
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
|
||||
|
||||
BASE_MODEL = os.environ.get("BASE_MODEL", None)
|
||||
assert (
|
||||
BASE_MODEL
|
||||
), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
|
||||
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
|
||||
|
||||
@ -35,7 +38,9 @@ lora_model = PeftModel.from_pretrained(
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight
|
||||
lora_weight = lora_model.base_model.model.model.layers[
|
||||
0
|
||||
].self_attn.q_proj.weight
|
||||
|
||||
assert torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
|
@ -1,20 +1,23 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
from peft import PeftModel, LoraConfig
|
||||
|
||||
import transformers
|
||||
|
||||
# Unused imports
|
||||
# from peft import LoraConfig
|
||||
from peft import PeftModel
|
||||
|
||||
assert (
|
||||
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
|
||||
|
||||
BASE_MODEL = None
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: E402
|
||||
|
||||
BASE_MODEL = os.environ.get("BASE_MODEL", None)
|
||||
assert (
|
||||
BASE_MODEL
|
||||
), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
|
||||
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
|
||||
|
||||
@ -54,22 +57,28 @@ n_heads = params["n_heads"]
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
|
||||
)
|
||||
|
||||
|
||||
def permute(w):
|
||||
return (
|
||||
w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
|
||||
w.view(n_heads, dim // n_heads // 2, 2, dim)
|
||||
.transpose(1, 2)
|
||||
.reshape(dim, dim)
|
||||
)
|
||||
|
||||
|
||||
def unpermute(w):
|
||||
return (
|
||||
w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim)
|
||||
w.view(n_heads, 2, dim // n_heads // 2, dim)
|
||||
.transpose(1, 2)
|
||||
.reshape(dim, dim)
|
||||
)
|
||||
|
||||
|
||||
def translate_state_dict_key(k):
|
||||
def translate_state_dict_key(k): # noqa: C901
|
||||
k = k.replace("base_model.model.", "")
|
||||
if k == "model.embed_tokens.weight":
|
||||
return "tok_embeddings.weight"
|
||||
|
51
finetune.py
51
finetune.py
@ -4,22 +4,28 @@ from typing import List
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
|
||||
"""
|
||||
Unused imports:
|
||||
import torch.nn as nn
|
||||
import bitsandbytes as bnb
|
||||
from datasets import load_dataset
|
||||
import transformers
|
||||
"""
|
||||
|
||||
# Catch when user should re-install transformers library
|
||||
assert (
|
||||
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
from peft import (
|
||||
prepare_model_for_int8_training,
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
|
||||
|
||||
from peft import ( # noqa: E402
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
get_peft_model_state_dict,
|
||||
prepare_model_for_int8_training,
|
||||
set_peft_model_state_dict,
|
||||
)
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
|
||||
|
||||
|
||||
def train(
|
||||
@ -44,7 +50,7 @@ def train(
|
||||
],
|
||||
# llm hyperparams
|
||||
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
||||
group_by_length: bool = False, # faster, but produces an odd training loss curve,
|
||||
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
||||
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
||||
):
|
||||
print(
|
||||
@ -86,7 +92,9 @@ def train(
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||
|
||||
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
|
||||
tokenizer.pad_token_id = (
|
||||
0 # unk. we want this to be different from the eos token
|
||||
)
|
||||
tokenizer.padding_side = "left" # Allow batched inference
|
||||
|
||||
def tokenize(prompt, add_eos_token=True):
|
||||
@ -138,7 +146,10 @@ def train(
|
||||
)
|
||||
model = get_peft_model(model, config)
|
||||
|
||||
if data_path.endswith(".json"): # todo: support jsonl
|
||||
data = load_dataset("json", data_files=data_path)
|
||||
else:
|
||||
data = load_dataset(data_path)
|
||||
|
||||
if resume_from_checkpoint:
|
||||
# Check the available weights and load them
|
||||
@ -149,7 +160,9 @@ def train(
|
||||
checkpoint_name = os.path.join(
|
||||
resume_from_checkpoint, "adapter_model.bin"
|
||||
) # only LoRA model - LoRA config above has to fit
|
||||
resume_from_checkpoint = False # So the trainer won't try loading its state
|
||||
resume_from_checkpoint = (
|
||||
False # So the trainer won't try loading its state
|
||||
)
|
||||
# The two files above have a different name depending on how they were saved, but are actually the same.
|
||||
if os.path.exists(checkpoint_name):
|
||||
print(f"Restarting from {checkpoint_name}")
|
||||
@ -164,8 +177,12 @@ def train(
|
||||
train_val = data["train"].train_test_split(
|
||||
test_size=val_set_size, shuffle=True, seed=42
|
||||
)
|
||||
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
||||
train_data = (
|
||||
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||
)
|
||||
val_data = (
|
||||
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
||||
)
|
||||
else:
|
||||
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||
val_data = None
|
||||
@ -201,7 +218,9 @@ def train(
|
||||
|
||||
old_state_dict = model.state_dict
|
||||
model.state_dict = (
|
||||
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
||||
lambda self, *_, **__: get_peft_model_state_dict(
|
||||
self, old_state_dict()
|
||||
)
|
||||
).__get__(model, type(model))
|
||||
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
@ -211,13 +230,15 @@ def train(
|
||||
|
||||
model.save_pretrained(output_dir)
|
||||
|
||||
print("\n If there's a warning about missing keys above, please disregard :)")
|
||||
print(
|
||||
"\n If there's a warning about missing keys above, please disregard :)"
|
||||
)
|
||||
|
||||
|
||||
def generate_prompt(data_point):
|
||||
# sorry about the formatting disaster gotta move fast
|
||||
if data_point["input"]:
|
||||
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
|
||||
|
||||
### Instruction:
|
||||
{data_point["instruction"]}
|
||||
@ -228,7 +249,7 @@ def generate_prompt(data_point):
|
||||
### Response:
|
||||
{data_point["output"]}"""
|
||||
else:
|
||||
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
|
||||
|
||||
### Instruction:
|
||||
{data_point["instruction"]}
|
||||
|
42
generate.py
42
generate.py
@ -1,15 +1,15 @@
|
||||
import sys
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
import transformers
|
||||
import gradio as gr
|
||||
import torch
|
||||
import transformers
|
||||
from peft import PeftModel
|
||||
|
||||
assert (
|
||||
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
|
||||
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
|
||||
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
@ -19,7 +19,7 @@ else:
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
except:
|
||||
except: # noqa: E722
|
||||
pass
|
||||
|
||||
|
||||
@ -28,9 +28,9 @@ def main(
|
||||
base_model: str = "",
|
||||
lora_weights: str = "tloen/alpaca-lora-7b",
|
||||
):
|
||||
assert base_model, (
|
||||
"Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
||||
)
|
||||
assert (
|
||||
base_model
|
||||
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||
if device == "cuda":
|
||||
@ -115,15 +115,23 @@ def main(
|
||||
fn=evaluate,
|
||||
inputs=[
|
||||
gr.components.Textbox(
|
||||
lines=2, label="Instruction", placeholder="Tell me about alpacas."
|
||||
lines=2,
|
||||
label="Instruction",
|
||||
placeholder="Tell me about alpacas.",
|
||||
),
|
||||
gr.components.Textbox(lines=2, label="Input", placeholder="none"),
|
||||
gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
|
||||
gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
|
||||
gr.components.Slider(
|
||||
minimum=0, maximum=1, value=0.1, label="Temperature"
|
||||
),
|
||||
gr.components.Slider(
|
||||
minimum=0, maximum=1, value=0.75, label="Top p"
|
||||
),
|
||||
gr.components.Slider(
|
||||
minimum=0, maximum=100, step=1, value=40, label="Top k"
|
||||
),
|
||||
gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
|
||||
gr.components.Slider(
|
||||
minimum=1, maximum=4, step=1, value=4, label="Beams"
|
||||
),
|
||||
gr.components.Slider(
|
||||
minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
|
||||
),
|
||||
@ -135,7 +143,7 @@ def main(
|
||||
)
|
||||
],
|
||||
title="🦙🌲 Alpaca-LoRA",
|
||||
description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
|
||||
description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).", # noqa: E501
|
||||
).launch()
|
||||
# Old testing code follows.
|
||||
|
||||
@ -147,7 +155,7 @@ def main(
|
||||
"Tell me about the king of France in 2019.",
|
||||
"List all Canadian provinces in alphabetical order.",
|
||||
"Write a Python program that prints the first 10 Fibonacci numbers.",
|
||||
"Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
|
||||
"Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501
|
||||
"Tell me five words that rhyme with 'shock'.",
|
||||
"Translate the sentence 'I have no mouth but I must scream' into Spanish.",
|
||||
"Count up from 1 to 500.",
|
||||
@ -160,7 +168,7 @@ def main(
|
||||
|
||||
def generate_prompt(instruction, input=None):
|
||||
if input:
|
||||
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
|
||||
|
||||
### Instruction:
|
||||
{instruction}
|
||||
@ -171,7 +179,7 @@ def generate_prompt(instruction, input=None):
|
||||
### Response:
|
||||
"""
|
||||
else:
|
||||
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
|
||||
|
||||
### Instruction:
|
||||
{instruction}
|
||||
|
@ -22,7 +22,9 @@
|
||||
"from transformers import LlamaTokenizer\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tokenizer = LlamaTokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\", add_eos_token=True)\n",
|
||||
"tokenizer = LlamaTokenizer.from_pretrained(\n",
|
||||
" \"decapoda-research/llama-7b-hf\", add_eos_token=True\n",
|
||||
")\n",
|
||||
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||
"tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
||||
"\n",
|
||||
@ -52,7 +54,9 @@
|
||||
"{data_point[\"output\"]}\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"data = data.map(lambda data_point: {\"prompt\": tokenizer(generate_prompt(data_point))})"
|
||||
"data = data.map(\n",
|
||||
" lambda data_point: {\"prompt\": tokenizer(generate_prompt(data_point))}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
8
pyproject.toml
Normal file
8
pyproject.toml
Normal file
@ -0,0 +1,8 @@
|
||||
[tool.black]
|
||||
line-length = 79
|
||||
|
||||
[tool.isort]
|
||||
include_trailing_comma = true
|
||||
line_length = 79
|
||||
multi_line_output = 3
|
||||
profile = "black"
|
@ -1,10 +1,10 @@
|
||||
datasets
|
||||
loralib
|
||||
sentencepiece
|
||||
git+https://github.com/huggingface/transformers.git
|
||||
accelerate
|
||||
bitsandbytes
|
||||
git+https://github.com/huggingface/peft.git
|
||||
gradio
|
||||
appdirs
|
||||
bitsandbytes
|
||||
black
|
||||
black[jupyter]
|
||||
datasets
|
||||
fire
|
||||
git+https://github.com/huggingface/peft.git
|
||||
git+https://github.com/huggingface/transformers.git
|
||||
gradio
|
||||
|
Loading…
Reference in New Issue
Block a user