Merge pull request #335 from nomic-ai/gptj

GPT-J
This commit is contained in:
Benjamin Schmidt 2023-04-13 16:59:09 -04:00 committed by GitHub
commit 51264f5eac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 802 additions and 196 deletions

3
.gitignore vendored
View File

@ -1,3 +1,6 @@
*.pkl
ckpts*
.deepspeed_env
*.jsonl *.jsonl
*tar.gz *tar.gz
ckpts** ckpts**

3
.gitmodules vendored
View File

@ -1,6 +1,3 @@
[submodule "transformers"]
path = transformers
url = https://github.com/huggingface/transformers.git
[submodule "peft"] [submodule "peft"]
path = peft path = peft
url = https://github.com/huggingface/peft.git url = https://github.com/huggingface/peft.git

17
GPT-J_MAP.md Normal file
View File

@ -0,0 +1,17 @@
# Inference on Training Data
## Run Inference
```bash
torchrun --master_port=29085 --nproc-per-node 8 inference.py --config=configs/inference/gptj.yaml
```
## Visualizations
```bash
python build_map.py
```
will build a map in `Atlas`, one using the internal clustering algorithm provided by Nomic and one using the embeddings generated by the finetuned model.

View File

@ -1,8 +1,11 @@
<h1 align="center">GPT4All</h1> <h1 align="center">GPT4All</h1>
<p align="center">Demo, data, and code to train an assistant-style large language model with ~800k GPT-3.5-Turbo Generations based on LLaMa</p> <p align="center">Demo, data, and code to train open-source assistant-style large language model based on GPT-J and LLaMa</p>
<p align="center">
<a href="https://s3.amazonaws.com/static.nomic.ai/gpt4all/2023_GPT4All-J_Technical_Report_2.pdf">:green_book: Technical Report 2: GPT4All-J </a>
</p>
<p align="center"> <p align="center">
<a href="https://s3.amazonaws.com/static.nomic.ai/gpt4all/2023_GPT4All_Technical_Report.pdf">:green_book: Technical Report</a> <a href="https://s3.amazonaws.com/static.nomic.ai/gpt4all/2023_GPT4All_Technical_Report.pdf">:green_book: Technical Report 1: GPT4All</a>
</p> </p>
<p align="center"> <p align="center">
@ -13,6 +16,23 @@
<a href="https://github.com/nomic-ai/gpt4all-ts">:computer: Official Typescript Bindings</a> <a href="https://github.com/nomic-ai/gpt4all-ts">:computer: Official Typescript Bindings</a>
</p> </p>
<p align="center">
<a href="https://github.com/nomic-ai/gpt4all-ui">:speech_balloon: Official Web Chat Interface</a>
</p>
<p align="center">
<a href="https://python.langchain.com/en/latest/modules/models/llms/integrations/gpt4all.html">🦜️🔗 Official Langchain Backend</a>
</p>
<p align="center">
<a href="https://discord.gg/mGZE39AS3e">Discord</a>
</p>
<p align="center">
<a href="https://github.com/nomic-ai/gpt4all-ts">:computer: Official Typescript Bindings</a>
</p>
<p align="center"> <p align="center">
<a href="https://github.com/nomic-ai/gpt4all-ui">:speech_balloon: Official Chat Interface</a> <a href="https://github.com/nomic-ai/gpt4all-ui">:speech_balloon: Official Chat Interface</a>
</p> </p>
@ -27,6 +47,56 @@
</p> </p>
<p align="center">
GPT4All is made possible by our compute partner <a href="https://www.paperspace.com/">Paperspace</a>.
</p>
## GPT4All-J: An Apache-2 Licensed GPT4All Model
![gpt4all-j-demo](https://user-images.githubusercontent.com/13879686/231876409-e3de1934-93bb-4b4b-9013-b491a969ebbc.gif)
Run runs on an M1 Mac (not sped up!)
### GPT4All-J Chat UI Installers
Installs a native chat-client with auto-update functionality that runs on your desktop with the GPT4All-J model baked into it.
[Mac/OSX](https://gtp4all.io/installers/gpt4all-0.1.0-Darwin.dmg)
[Windows](https://gpt4all.io/installers/gpt4all-0.1.0-win64.exe)
[Ubuntu](https://gpt4all.io/installers/gpt4all-0.1.0-Linux.run)
These files are not yet cert signed by Windows/Apple so you will see security warnings on initial installation. We did not want to delay release while waiting for their process to complete.
Find the most up-to-date information on the [GPT4All Website](https://gpt4all.io/)
### Raw Model
[ggml Model Download Link](https://gpt4all.io/ggml-gpt4all-j.bin)
Note this model is only compatible with the C++ bindings found [here](https://github.com/nomic-ai/gpt4all-chat). It will not work with any existing llama.cpp bindings as we had to do a large fork of llama.cpp. GPT4All will support the ecosystem around this new C++ backend going forward.
Python bindings are imminent and will be integrated into this [repository](https://github.com/nomic-ai/pyllamacpp). Stay tuned on the [GPT4All discord](https://discord.gg/mGZE39AS3e) for updates.
## Training GPT4All-J
Please see [GPT4All-J Technical Report]() for details.
### GPT4All-J Training Data
- We are releasing the curated training data for anyone to replicate GPT4All-J here: [GPT4All-J Training Data](https://huggingface.co/datasets/nomic-ai/gpt4all-j-prompt-generations)
- [Atlas Map of Prompts](https://atlas.nomic.ai/map/gpt4all-j-prompts-curated)
- [Atlas Map of Responses](https://atlas.nomic.ai/map/gpt4all-j-response-curated)
### GPT4All-J Training Instructions
```bash
accelerate launch --dynamo_backend=inductor --num_processes=8 --num_machines=1 --machine_rank=0 --deepspeed_multinode_launcher standard --mixed_precision=bf16 --use_deepspeed --deepspeed_config_file=configs/deepspeed/ds_config_gptj.json train.py --config configs/train/finetune_gptj.yaml
```
# Original GPT4All Model (based on GPL Licensed LLaMa)
@ -113,8 +183,8 @@ Feel free to convert this to a more structured table.
# Roadmap # Roadmap
## Short Term ## Short Term
- <span style="color:green">(IN PROGRESS)</span> Train a GPT4All model based on GPTJ to alleviate llama distribution issues. - <span style="color:green">(Done)</span> Train a GPT4All model based on GPTJ to alleviate llama distribution issues.
- <span style="color:green">(IN PROGRESS)</span> Create improved CPU and GPU interfaces for this model. - <span style="color:green">(Done)</span> Create improved CPU and GPU interfaces for this model.
- <span style="color:green">(Done)</span> [Integrate llama.cpp bindings](https://github.com/nomic-ai/pyllamacpp) - <span style="color:green">(Done)</span> [Integrate llama.cpp bindings](https://github.com/nomic-ai/pyllamacpp)
- <span style="color:green">(Done)</span> [Create a good conversational chat interface for the model.](https://github.com/nomic-ai/gpt4all-ui) - <span style="color:green">(Done)</span> [Create a good conversational chat interface for the model.](https://github.com/nomic-ai/gpt4all-ui)
- <span style="color:green">(Done)</span> [Allow users to opt in and submit their chats for subsequent training runs](https://github.com/nomic-ai/gpt4all-ui) - <span style="color:green">(Done)</span> [Allow users to opt in and submit their chats for subsequent training runs](https://github.com/nomic-ai/gpt4all-ui)
@ -122,7 +192,7 @@ Feel free to convert this to a more structured table.
## Medium Term ## Medium Term
- <span style="color:red">(NOT STARTED)</span> Integrate GPT4All with [Atlas](https://atlas.nomic.ai) to allow for document retrieval. - <span style="color:red">(NOT STARTED)</span> Integrate GPT4All with [Atlas](https://atlas.nomic.ai) to allow for document retrieval.
- BLOCKED by GPT4All based on GPTJ - BLOCKED by GPT4All based on GPTJ
- <span style="color:red">(NOT STARTED)</span> Integrate GPT4All with Langchain. - <span style="color:red">(Done)</span> Integrate GPT4All with Langchain.
- <span style="color:green">(IN PROGRESS)</span> Build easy custom training scripts to allow users to fine tune models. - <span style="color:green">(IN PROGRESS)</span> Build easy custom training scripts to allow users to fine tune models.
## Long Term ## Long Term
@ -131,9 +201,11 @@ Feel free to convert this to a more structured table.
# Reproducibility # Reproducibility
Trained LoRa Weights: Trained Model Weights:
- gpt4all-lora (four full epochs of training): https://huggingface.co/nomic-ai/gpt4all-lora - gpt4all-lora (four full epochs of training): https://huggingface.co/nomic-ai/gpt4all-lora
- gpt4all-lora-epoch-2 (three full epochs of training) https://huggingface.co/nomic-ai/gpt4all-lora-epoch-2 - gpt4all-lora-epoch-2 (three full epochs of training) https://huggingface.co/nomic-ai/gpt4all-lora-epoch-2
- gpt4all-j (one full epoch of training) (https://huggingface.co/nomic-ai/gpt4all-j)
- gpt4all-j-lora (one full epoch of training) (https://huggingface.co/nomic-ai/gpt4all-j-lora)
Raw Data: Raw Data:
- [Training Data Without P3](https://huggingface.co/datasets/nomic-ai/gpt4all_prompt_generations) - [Training Data Without P3](https://huggingface.co/datasets/nomic-ai/gpt4all_prompt_generations)
@ -159,9 +231,6 @@ Setup the environment
``` ```
python -m pip install -r requirements.txt python -m pip install -r requirements.txt
cd transformers
pip install -e .
cd ../peft cd ../peft
pip install -e . pip install -e .
``` ```

View File

@ -23,7 +23,7 @@ We used the initial parameters:
| Weight decay | 0 | | Weight decay | 0 |
| Warmup Steps | 100 | | Warmup Steps | 100 |
We randomly shuffle and set aside %5 of the data for validation. We randomly shuffle and set aside 5% of the data for validation.
We had an initial bug in logging the training loss but we noticed a decrease in validation loss. We had an initial bug in logging the training loss but we noticed a decrease in validation loss.
@ -235,3 +235,49 @@ Taking inspiration from [the Alpaca Repo](https://github.com/tatsu-lab/stanford_
Comparing our model LoRa to the [Alpaca LoRa](https://huggingface.co/tloen/alpaca-lora-7b), our model has lower perplexity. Qualitatively, training on 3 epochs performed the best on perplexity as well as qualitative examples. Comparing our model LoRa to the [Alpaca LoRa](https://huggingface.co/tloen/alpaca-lora-7b), our model has lower perplexity. Qualitatively, training on 3 epochs performed the best on perplexity as well as qualitative examples.
We tried training a full model using the parameters above, but found that during the second epoch the model diverged and samples generated post training were worse than the first epoch. We tried training a full model using the parameters above, but found that during the second epoch the model diverged and samples generated post training were worse than the first epoch.
## GPT-J Training
### Model Training Divergence
We trained multiple [GPT-J models](https://huggingface.co/EleutherAI/gpt-j-6b) with varying success. We found that training the full model lead to diverged post epoch 1. ![](figs/overfit-gpt-j.png)
We release the checkpoint after epoch 1.
Using Atlas, we extracted the embeddings of each point in the dataset and calculated the loss per sequence. We then uploaded [this to Atlas](https://atlas.nomic.ai/map/gpt4all-j-post-epoch-1-embeddings) and noticed that the higher loss items seem to cluster. On further inspection, the highest density clusters seemded to be of prompt/response pairs that asked for creative-like generations such as `Generate a story about ...` ![](figs/clustering_overfit.png)
### GPT4All-J Hyperparameters
We varied learning rate, learning rate schedule, and weight decay following suggestions from the [original GPT-J codebase](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md) but found no real performance difference (qualitatively or quantitatively) when varying these parameters.
The final model was trained using the following hyperparameters with a linear warmup followed by constant learning rate:
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS | 32 |
| Global BS | 256 |
| Learning rate | 2e-5 |
| Epochs | 2 |
| Max length | 1024 |
| Weight decay | 0 |
| Warmup Steps | 500 |
The LoRA model was trained using using the following hyperparameters with a linear warmup followed by constant learning rate:
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS | 4 |
| Global BS | 32 |
| Learning rate | 2e-5 |
| Epochs | 2 |
| Max length | 1024 |
| Weight decay | 0 |
| Warmup Steps | 500 |

54
build_map.py Normal file
View File

@ -0,0 +1,54 @@
import numpy as np
from nomic import atlas
import glob
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from sklearn.decomposition import PCA
files = glob.glob("inference/*.jsonl")
print(files)
df = concatenate_datasets([load_dataset("json", data_files=file, split="train") for file in tqdm(files)])
print(len(df))
print(df)
df = df.map(lambda example: {"inputs": [prompt + "\n" + response for prompt, response in zip(example["prompt"], example["response"])]},
batched=True,
num_proc=64)
df = df.map(lambda example: {"trained_on": [int(t) for t in example["is_train"]]},
batched=True,
num_proc=64)
df = df.remove_columns("is_train")
text = df.remove_columns(["labels", "input_ids", "embeddings"])
text_df = [text[i] for i in range(len(text))]
atlas.map_text(text_df, indexed_field="inputs",
name="CHANGE ME!",
colorable_fields=["source", "loss", "trained_on"],
reset_project_if_exists=True,
)
# index is local to train/test split, regenerate
data = df.remove_columns(["labels", "input_ids", "index"])
data = data.add_column("index", list(range(len(data))))
# max embed dim is 2048 for now
# note! this is slow in pyarrow/hf datasets
embeddings = np.array(data["embeddings"])
print("embeddings shape:", embeddings.shape)
embeddings = PCA(n_components=2048).fit_transform(embeddings)
data = data.remove_columns(["embeddings"])
columns = data.to_pandas().to_dict("records")
atlas.map_embeddings(embeddings,
data=columns,
id_field="index",
name="CHANGE ME!",
colorable_fields=["source", "loss", "trained_on"],
build_topic_model=True,
topic_label_field="inputs",
reset_project_if_exists=True,)

View File

@ -64,6 +64,7 @@ for file in glob.glob(os.path.join(prompt_generation_dir, "*.jsonl")):
df = df.dropna(subset=['prompt', 'response']) df = df.dropna(subset=['prompt', 'response'])
df = df[df['prompt'] != ''] df = df[df['prompt'] != '']
df = df[df['response'] != ''] df = df[df['response'] != '']
df = df[df["prompt"].str.len() > 1]
curr_len = len(df) curr_len = len(df)
print(f"Removed {prev_len - curr_len} rows") print(f"Removed {prev_len - curr_len} rows")

View File

@ -0,0 +1,48 @@
{
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"train_micro_batch_size_per_gpu": "auto",
"fp16": {
"enabled": "auto",
"min_loss_scale": 1,
"loss_scale_window": 1000,
"hysteresis": 2,
"initial_scale_power": 32
},
"bf16": {
"enabled": "auto"
},
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "none"
},
"offload_optimizer": {
"device": "none"
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear"
}
}
}

View File

@ -0,0 +1,48 @@
{
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"train_micro_batch_size_per_gpu": "auto",
"fp16": {
"enabled": "auto",
"min_loss_scale": 1,
"loss_scale_window": 1000,
"hysteresis": 2,
"initial_scale_power": 32
},
"bf16": {
"enabled": "auto"
},
"gradient_clipping": 1,
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "cpu"
},
"offload_optimizer": {
"device": "cpu"
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear"
}
}
}

View File

@ -1,15 +0,0 @@
# model/tokenizer
model_name: # update with llama 7b
tokenizer_name: # update with llama 7b
lora: true
lora_path: "nomic-ai/gpt4all-lora"
max_new_tokens: 512
temperature: 0.001
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@ -1,17 +1,5 @@
# model/tokenizer # model/tokenizer
model_name: # update with llama model name model_name: "zpn/llama-7b"
tokenizer_name: # update with llama model name tokenizer_name: "zpn/llama-7b"
lora: true lora: true
lora_path: "tloen/alpaca-lora-7b" lora_path: "tloen/alpaca-lora-7b"
max_new_tokens: 512
temperature: 0.001
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@ -0,0 +1,4 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_0"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: false

View File

@ -0,0 +1,5 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6B"
lora: true
lora_path: "nomic-ai/gpt4all-gptj-lora-epoch_1"

View File

@ -0,0 +1,5 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
lora: true
lora_path: "nomic-ai/gpt4all-lora"

View File

@ -1,15 +0,0 @@
# model/tokenizer
model_name: # update
tokenizer_name: # update
lora: true
lora_path: # update
max_new_tokens: 512
temperature: 0.001
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@ -1,15 +0,0 @@
# model/tokenizer
model_name: # update
tokenizer_name: # update
lora: true
lora_path: # update
max_new_tokens: 512
temperature: 0.001
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@ -1,6 +1,6 @@
# model/tokenizer # model/tokenizer
model_name: # REPLACE HERE with the base llama model model_name: "zpn/llama-7b"
tokenizer_name: # REPLACE HERE with the llama tokenizer tokenizer_name: "zpn/llama-7b"
lora: true lora: true
lora_path: "nomic-ai/gpt4all-lora" lora_path: "nomic-ai/gpt4all-lora"

View File

@ -1,7 +1,8 @@
# model/tokenizer # model/tokenizer
model_name: # update model_name: "nomic-ai/gpt4all-warmup-lr-epoch_1"
tokenizer_name: # update tokenizer_name: "EleutherAI/gpt-j-6b"
lora_path: "no-lora" lora: false
max_new_tokens: 512 max_new_tokens: 512
temperature: 0.001 temperature: 0.001

View File

@ -0,0 +1,15 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: true
lora_path: "nomic-ai/gpt4all-gptj-lora-epoch_0"
max_new_tokens: 512
temperature: 0
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@ -0,0 +1,14 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_1"
tokenizer_name: "EleutherAI/gpt-j-6B"
# dataset
streaming: false
num_proc: 64
dataset_path: "nomic-ai/turbo-500k-multi"
max_length: 1024
batch_size: 32
# logging
seed: 42

View File

@ -0,0 +1,33 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6B"
tokenizer_name: "EleutherAI/gpt-j-6B"
gradient_checkpointing: true
save_name: # CHANGE
# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 32
# train dynamics
lr: 2.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 500
eval_steps: 105
save_every: 500
log_grads_every: 100
output_dir: # CHANGE
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 2
# logging
wandb: true
wandb_entity: # CHANGE
wandb_project_name: # CHANGE
seed: 42

View File

@ -0,0 +1,33 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6b"
gradient_checkpointing: false
save_name: # CHANGE
# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 1
# train dynamics
lr: 2.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 500
eval_steps: 105
save_every: 500
log_grads_every: 500
output_dir: # CHANGE
checkpoint: null
lora: true
warmup_steps: 500
num_epochs: 2
# logging
wandb: true
wandb_entity: # CHANGE
wandb_project_name: # CHANGE
seed: 42

View File

@ -2,17 +2,19 @@
model_name: # update model_name: # update
tokenizer_name: # update tokenizer_name: # update
gradient_checkpointing: false gradient_checkpointing: false
save_name: "nomic-ai/gpt4all-lora-multi-turn" save_name: # CHANGE
# dataset # dataset
streaming: false streaming: false
num_proc: 64 num_proc: 64
dataset_path: "data_multiturn" dataset_path: "nomic-ai/turbo-500k-multi"
max_length: 1024 max_length: 1024
batch_size: 4 batch_size: 4
# train dynamics # train dynamics
lr: 5.0e-5 lr: 5.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 2000 eval_every: 2000
eval_steps: 100 eval_steps: 100
save_every: 2000 save_every: 2000

8
create_hostname.sh Normal file
View File

@ -0,0 +1,8 @@
#!/bin/bash
export WORKER_IP=$1
N_GPUS=8
# create dir if doesn't exist
sudo mkdir -p /job
printf "localhost slots=$N_GPUS\n$WORKER_IP slots=$N_GPUS" | sudo tee /job/hostfile
echo /job/hostfile

115
data.py
View File

@ -9,44 +9,49 @@ from transformers import DefaultDataCollator
def tokenize_inputs(config, tokenizer, examples): def tokenize_inputs(config, tokenizer, examples):
max_length = config["max_length"] max_length = config["max_length"]
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
# ignore bos
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
out = {"labels": [], "attention_mask": []} # hacky backward compatible
for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])): different_eos = tokenizer.eos_token != "</s>"
input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze() out = {"labels": [], "input_ids": []}
input_len = len(input_tokens) for prompt, response in zip(examples["prompt"], examples["response"]):
if different_eos:
if response.count("</s> \n") > 0:
response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
# plus one since we remove bos from response prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
# but we subtract one since we want to add eos token
remaining_tokens = max_length - input_len - len(newline_tokens) + 1
# remove bos
target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
input_ids[i, :input_len] = input_tokens # hack if our prompt is super long
# add newline between prompt and response # we need to include some labels so we arbitrarily trunacate at max_length // 2
newline_plus_inputs = input_len + len(newline_tokens) # if the length is too long
input_ids[i, input_len: newline_plus_inputs] = newline_tokens if prompt_len >= max_length // 2:
# if prompt is too long, truncate
# but make sure to truncate to at max 1024 tokens
new_len = min(max_length // 2, len(prompt) // 2)
prompt = prompt[:new_len]
# get new prompt length
prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
# add target tokens, remove bos assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
# add eos token; ensure generation stops if inputs aren't truncated
# we don't want long code to stop generating if truncated during training
if newline_plus_inputs + len(target_tokens) < max_length:
input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
labels = input_ids[i].clone() input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
labels[: newline_plus_inputs] = -100 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
labels[labels == tokenizer.pad_token_id] = -100
# to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int() labels = input_tokens.clone()
labels[:prompt_len] = -100
if len(labels) < max_length:
# pad to max_length with -100
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
if (labels == -100).sum() == len(labels) - 1:
print(prompt)
print(response)
raise
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
out["labels"].append(labels) out["labels"].append(labels)
out["attention_mask"].append(attention_mask) out["input_ids"].append(input_tokens)
out["input_ids"] = input_ids
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()} out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
@ -110,3 +115,53 @@ def load_data(config, tokenizer):
) )
return train_dataloader, val_dataloader return train_dataloader, val_dataloader
def load_data_for_inference(config, tokenizer):
dataset_path = config["dataset_path"]
if os.path.exists(dataset_path):
# check if path is a directory
if os.path.isdir(dataset_path):
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
else:
files = [dataset_path]
print(f"Reading files {files}")
dataset = load_dataset("json", data_files=files, split="train")
else:
dataset = load_dataset(dataset_path, split="train")
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]
train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
# select first N batches that are divisible by batch_size
# gather is a bit annoying (or the way I'm using it) to get uneven batches as it duplicates data
train_dataset = train_dataset.select(range((len(train_dataset) // config["batch_size"]) * config["batch_size"]))
val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
val_dataset = val_dataset.select(range((len(val_dataset) // config["batch_size"]) * config["batch_size"]))
if config["streaming"] is False:
kwargs = {"num_proc": config["num_proc"]}
else:
kwargs = {}
# tokenize inputs and return labels and attention mask
train_dataset = train_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
**kwargs
)
val_dataset = val_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
**kwargs
)
train_dataset = train_dataset.with_format("torch")
val_dataset = val_dataset.with_format("torch")
return train_dataset, val_dataset

View File

@ -6,18 +6,20 @@ from matplotlib import pyplot as plt
plt.figure() plt.figure()
for fpath in glob.glob('./eval_data/*.pkl'): for fpath in glob.glob('./eval_data/*.pkl'):
parts = fpath.split('__') parts = fpath.split('__')
model_name = parts[1].replace('model-', '').replace('.pkl', '') model_name = "-".join(fpath.replace(".pkl", "").split("_")[2:])
lora_name = parts[2].replace('lora-', '').replace('.pkl', '')
with open(fpath, 'rb') as f: with open(fpath, 'rb') as f:
data = pickle.load(f) data = pickle.load(f)
perplexities = data['perplexities'] perplexities = data['perplexities']
perplexities = np.nan_to_num(perplexities, 100) perplexities = np.nan_to_num(perplexities, 100)
perplexities = np.clip(perplexities, 0, 100) perplexities = np.clip(perplexities, 0, 100)
if 'nomic' in fpath: if 'alpaca' not in fpath:
label = 'GPT4all-lora' identifier = model_name = "-".join(fpath.replace(".pkl", "").split("eval__model-")[1:])
label = 'GPT4all-'
label += identifier
else: else:
label = 'alpaca-lora' label = 'alpaca-lora'
plt.hist(perplexities, label=label, alpha=.5) plt.hist(perplexities, label=label, alpha=.5, bins=50)
plt.xlabel('Perplexity') plt.xlabel('Perplexity')
plt.ylabel('Frequency') plt.ylabel('Frequency')

View File

@ -49,28 +49,6 @@ def eval_example(model, tokenizer, example, config):
input = tokenizer(prompt, return_tensors="pt") input = tokenizer(prompt, return_tensors="pt")
input = {k: v.to(model.device) for k, v in input.items()} input = {k: v.to(model.device) for k, v in input.items()}
continuations = []
tokenized_continuations = []
trajectories = []
for i in range(1):
with torch.no_grad():
outputs = model.generate(input_ids=input['input_ids'],
max_new_tokens=config["max_new_tokens"],
min_new_tokens=5,
temperature=config["temperature"],
repetition_penalty=1.0,
do_sample=True)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
y = model(input_ids=outputs)
trajectory = y.hidden_states[0].detach().cpu().numpy()[0]
trajectory = trajectory / np.linalg.norm(trajectory, axis=1, keepdims=True)
trajectory = np.cumsum(trajectory, axis=0) / np.arange(1, trajectory.shape[0]+1).reshape(-1, 1)
trajectories.append(trajectory)
continuations.append(decoded)
tokenized_continuations.append(tokenizer.tokenize(decoded))
#compute the ground truth perplexity #compute the ground truth perplexity
gt_input = tokenizer(gt, return_tensors="pt") gt_input = tokenizer(gt, return_tensors="pt")
gt_input = {k: v.to(model.device) for k, v in gt_input.items()} gt_input = {k: v.to(model.device) for k, v in gt_input.items()}
@ -101,30 +79,23 @@ def eval_example(model, tokenizer, example, config):
print(prompt) print(prompt)
print(80*'-') print(80*'-')
for continuation in continuations:
print(continuation)
print(80*'-')
return ppl, trajectories, continuations, tokenized_continuations
return ppl
def do_eval(config): def do_eval(config):
eval_data = read_jsonl_file('eval_data/user_oriented_instructions.jsonl') eval_data = read_jsonl_file('eval_data/user_oriented_instructions.jsonl')
model, tokenizer = setup_model(config) model, tokenizer = setup_model(config)
all_trajectories = []
all_perplexities = [] all_perplexities = []
all_continuations = []
all_tokenized_continuations = []
for example in tqdm(eval_data): for example in tqdm(eval_data):
gt_perplexity, trajectories, continuations, tokenized_continuations = eval_example(model, tokenizer, example, config) gt_perplexity = eval_example(model, tokenizer, example, config)
all_trajectories.append(trajectories)
all_perplexities.append(gt_perplexity) all_perplexities.append(gt_perplexity)
all_continuations.append(continuations)
with open('eval_data/eval__model-{}__lora-{}.pkl'.format(config['model_name'].replace('/', '_'), config['lora_path'].replace('/', '_')), 'wb') as f:
r = {'trajectories': all_trajectories, name = f"eval_data/eval__model-{config['model_name'].replace('/', '_')}{'__lora-' + config['lora_path'].replace('/', '_') if config['lora'] else ''}.pkl"
'perplexities': all_perplexities,
'continuations': all_continuations, with open(name, 'wb') as f:
'tokenized_continuations': all_tokenized_continuations} r = {'perplexities': all_perplexities}
pickle.dump(r, f) pickle.dump(r, f)

BIN
figs/clustering_overfit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

BIN
figs/overfit-gpt-j.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 356 KiB

204
inference.py Normal file
View File

@ -0,0 +1,204 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
from argparse import ArgumentParser
from read import read_config
from accelerate.utils import set_seed
from data import load_data_for_inference
from tqdm import tqdm
from datasets import Dataset
import torch.distributed as dist
from transformers.trainer_pt_utils import nested_numpify
from transformers import DefaultDataCollator
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
import pyarrow as pa
from pyarrow import compute as pc
def calc_cross_entropy_no_reduction(lm_logits, labels):
# calculate cross entropy across batch dim
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(reduction='none')
loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)
return loss
def rank0_print(msg):
if dist.get_rank() == 0:
print(msg)
def inference(config):
set_seed(config['seed'])
rank0_print(f"World size: {dist.get_world_size()}")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# llama has no pad token, set it to new token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
num_processes = dist.get_world_size()
local_rank = dist.get_rank()
train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
train_dataloader = DataLoader(
train_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
sampler=train_sampler,
drop_last=True
)
val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
val_dataloader = DataLoader(
val_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
sampler=val_sampler,
drop_last=True
)
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
model.to(f"cuda:{local_rank}")
with torch.no_grad():
train_outputs = {"loss": [], "embeddings": [], "index": []}
for batch in tqdm(train_dataloader, disable=local_rank != 0):
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
train_outputs["loss"].extend(loss)
embeddings = outputs.hidden_states[-1]
batch_size = batch["input_ids"].shape[0]
sequence_lengths = []
# since we use mutiturn with multiple <|endoftext|>, we need to find the place where
# <|endoftext|> is repeated
for item in batch["input_ids"]:
indices = torch.where(item == tokenizer.pad_token_id)[0]
found = False
for index in indices:
# case where sequence is less than max length
if torch.all(item[index:] == tokenizer.pad_token_id):
sequence_lengths.append(index)
found = True
break
# case where sequence is >= max length
if not found:
sequence_lengths.append(len(item) - 1)
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
train_outputs["embeddings"].append(pooled_logits)
train_outputs["index"].extend(batch["index"].to(model.device))
torch.cuda.empty_cache()
train_outputs = nested_numpify(train_outputs)
# stack since they're 0-dim arrays
train_outputs["index"] = np.stack(train_outputs["index"])
train_outputs["loss"] = np.stack(train_outputs["loss"])
train_outputs["embeddings"] = np.concatenate(train_outputs["embeddings"])
df_train = Dataset.from_dict(train_outputs)
curr_idx = df_train["index"]
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = train_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"])
filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
val_outputs = {"loss": [], "embeddings": [], "index": []}
for batch in tqdm(val_dataloader, disable=local_rank != 0):
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
val_outputs["loss"].extend(loss)
embeddings = outputs.hidden_states[-1]
batch_size = batch["input_ids"].shape[0]
sequence_lengths = []
# since we use mutiturn with multiple <|endoftext|>, we need to find the place where
# <|endoftext|> is repeated
for item in batch["input_ids"]:
indices = torch.where(item == tokenizer.pad_token_id)[0]
found = False
for index in indices:
# case where sequence is less than max length
if torch.all(item[index:] == tokenizer.pad_token_id):
sequence_lengths.append(index)
found = True
break
# case where sequence is >= max length
if not found:
sequence_lengths.append(len(item) - 1)
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
val_outputs["embeddings"].append(pooled_logits)
val_outputs["index"].extend(batch["index"].to(model.device))
torch.cuda.empty_cache()
val_outputs = nested_numpify(val_outputs)
val_outputs["index"] = np.stack(val_outputs["index"])
val_outputs["loss"] = np.stack(val_outputs["loss"])
val_outputs["embeddings"] = np.concatenate(val_outputs["embeddings"])
df_val = Dataset.from_dict(val_outputs)
curr_idx = df_val["index"]
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = val_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
filtered_val = filtered_val.add_column("loss", df_val["loss"])
filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
def main():
dist.init_process_group("nccl")
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
inference(config)
if __name__ == "__main__":
# parse arguments by reading in a config
main()

View File

@ -2,7 +2,7 @@ accelerate
datasets datasets
torchmetrics torchmetrics
evaluate evaluate
transformers transformers>=4.28.0
wandb wandb
pip pip
peft peft
@ -10,3 +10,6 @@ nodelist-inflator
deepspeed deepspeed
sentencepiece sentencepiece
jsonlines jsonlines
nomic
scikit-learn
matplotlib

102
train.py
View File

@ -1,8 +1,7 @@
import os import os
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
from transformers.trainer_pt_utils import get_parameter_names
import torch import torch
import torch.nn as nn from torch.optim import AdamW
from argparse import ArgumentParser from argparse import ArgumentParser
from read import read_config from read import read_config
from accelerate import Accelerator from accelerate import Accelerator
@ -11,7 +10,9 @@ from peft import get_peft_model, LoraConfig, TaskType
from data import load_data from data import load_data
from torchmetrics import MeanMetric from torchmetrics import MeanMetric
from tqdm import tqdm from tqdm import tqdm
import wandb
torch.backends.cuda.matmul.allow_tf32 = True
def format_metrics(metrics, split, prefix=""): def format_metrics(metrics, split, prefix=""):
log = f"[{split}]" + prefix log = f"[{split}]" + prefix
@ -20,17 +21,12 @@ def format_metrics(metrics, split, prefix=""):
return log return log
def evaluate(config, model, val_dataloader): def evaluate(model, val_dataloader):
model.eval() model.eval()
val_loss = MeanMetric().to(model.device) val_loss = MeanMetric(nan_strategy="error").to(model.device)
with torch.no_grad(): with torch.no_grad():
for i, batch in enumerate( for batch in tqdm(val_dataloader):
tqdm(val_dataloader),
):
if i == config["eval_steps"]:
break
loss = model(**batch).loss loss = model(**batch).loss
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()}) loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
@ -46,11 +42,10 @@ def train(accelerator, config):
accelerator.print(config) accelerator.print(config)
accelerator.print(f"Using {accelerator.num_processes} GPUs") accelerator.print(f"Using {accelerator.num_processes} GPUs")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name']) tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# llama has no pad token, set it to new token # if no pad token, set it to eos
if tokenizer.pad_token is None: if tokenizer.pad_token is None:
# these tokens are already in the vocab, just not mapped correctly tokenizer.pad_token = tokenizer.eos_token
added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
with accelerator.main_process_first(): with accelerator.main_process_first():
@ -61,10 +56,6 @@ def train(accelerator, config):
model = AutoModelForCausalLM.from_pretrained(config["model_name"], model = AutoModelForCausalLM.from_pretrained(config["model_name"],
use_cache=False if checkpoint else True, use_cache=False if checkpoint else True,
trust_remote_code=True) trust_remote_code=True)
if added_tokens > 0:
model.resize_token_embeddings(len(tokenizer))
if checkpoint: if checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
@ -77,7 +68,7 @@ def train(accelerator, config):
model.print_trainable_parameters() model.print_trainable_parameters()
optimizer_cls = ( optimizer_cls = (
torch.optim.AdamW AdamW
if accelerator.state.deepspeed_plugin is None if accelerator.state.deepspeed_plugin is None
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
else DummyOptim else DummyOptim
@ -85,11 +76,35 @@ def train(accelerator, config):
# karpathy doesn't decay embeddding, maybe we should exclude # karpathy doesn't decay embeddding, maybe we should exclude
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
optimizer = optimizer_cls(model.parameters(), lr=config["lr"]) optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
# scheduler defined in Deepspeed config if accelerator.state.deepspeed_plugin is not None:
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
"gradient_accumulation_steps"
]
# decay to min_lr instead of 0
lr_ratio = config["min_lr"] / config["lr"]
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
# instead of decaying to zero, decay to ratio of min_lr / lr
total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
accelerator.print(f"Total training steps: {total_num_steps}")
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
scheduler = get_scheduler(
name="cosine",
optimizer=optimizer,
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
num_training_steps=total_num_steps,
)
else:
scheduler = DummyScheduler( scheduler = DummyScheduler(
optimizer, warmup_num_steps=config["warmup_steps"], optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
) )
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare( model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
@ -108,21 +123,25 @@ def train(accelerator, config):
accelerator.skip_first_batches(train_dataloader, resume_step) accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}") accelerator.print(f"Resuming from step {resume_step}")
train_loss = MeanMetric().to(model.device)
if accelerator.state.deepspeed_plugin is not None: # log gradients
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ if accelerator.is_main_process and config["wandb"]:
"gradient_accumulation_steps" wandb.watch(model, log_freq=config["log_grads_every"], log="all")
]
for epoch in range(config["num_epochs"]): for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader)): for step, batch in enumerate(tqdm(train_dataloader)):
model.train() model.train()
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
loss = loss / gradient_accumulation_steps
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
train_loss.update(loss_values["loss"])
loss = loss / gradient_accumulation_steps
accelerator.backward(loss) accelerator.backward(loss)
# get gradient norm of all params
# log LR in case something weird happens # log LR in case something weird happens
if step > 0 and step % (config["eval_every"] // 10) == 0: if step > 0 and step % (config["eval_every"] // 10) == 0:
@ -135,14 +154,13 @@ def train(accelerator, config):
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
train_loss.update(loss_values["loss"])
if step > 0 and step % config["save_every"] == 0: if step > 0 and step % config["save_every"] == 0:
accelerator.save_state(f"{config['output_dir']}/step_{step}") curr_step = step + epoch * len(train_dataloader)
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
if step > 0 and step % config["eval_every"] == 0: if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(config, model, val_dataloader) val_loss = evaluate(model, val_dataloader)
log_train = { log_train = {
"train_loss": train_loss.compute() "train_loss": train_loss.compute()
@ -165,9 +183,20 @@ def train(accelerator, config):
accelerator.print(f"Pushing to HF hub") accelerator.print(f"Pushing to HF hub")
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
try:
if accelerator.is_main_process: if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + "_first_epoch", private=True) unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
except Exception as e:
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
@ -178,9 +207,6 @@ def train(accelerator, config):
state_dict=accelerator.get_state_dict(model), state_dict=accelerator.get_state_dict(model),
) )
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"], private=True)
accelerator.end_training() accelerator.end_training()

@ -1 +0,0 @@
Subproject commit cae78c46d658a8e496a815c2ee49b9b178fb9c9a