feat: wip training log

Zach Nussbaum 2023-04-13 18:41:39 +00:00
parent 1280edd744
commit b170eb9aae
3 changed files with 43 additions and 0 deletions


@@ -235,3 +235,46 @@ Taking inspiration from [the Alpaca Repo](https://github.com/tatsu-lab/stanford_
Comparing our LoRA model to the [Alpaca LoRA](https://huggingface.co/tloen/alpaca-lora-7b), our model has lower perplexity. Training for 3 epochs performed best, both on perplexity and on qualitative examples.
We tried training a full model using the parameters above, but found that the model diverged during the second epoch and that samples generated after training were worse than those from the first epoch.
## GPT-J Training
### Model Training Divergence
We trained multiple [GPT-J models](https://huggingface.co/EleutherAI/gpt-j-6b) with varying success. We found that training the full model led to divergence after epoch 1. ![](figs/overfit-gpt-j.png) We release the checkpoint from the end of epoch 1.
Using Atlas, we extracted the embeddings and calculated the per-sequence loss. We then uploaded [this to Atlas](https://atlas.nomic.ai/map/gpt4all-j-post-epoch-1-embeddings) and noticed that the higher-loss items seem to cluster. On further inspection, the highest-density clusters seemed to be prompt/response pairs that asked for creative generations such as `Generate a story about ...` ![](figs/clustering_overfit.png)
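
As a rough sketch of the per-sequence loss computation described above (our assumption of the approach, not the exact script used), the mean cross-entropy of each prompt/response pair can be computed from the Hugging Face checkpoint and attached to its embedding before uploading to Atlas:

```python
# Minimal sketch (assumed, not the exact script used): score each prompt/response
# pair with a mean per-token cross-entropy so the values can be attached to the
# embeddings uploaded to Atlas.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "EleutherAI/gpt-j-6b"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16).cuda().eval()

@torch.no_grad()
def per_sequence_loss(text: str) -> float:
    """Mean cross-entropy over the tokens of a single sequence."""
    ids = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024).input_ids.cuda()
    out = model(ids, labels=ids)  # labels == inputs gives the LM loss directly
    return out.loss.item()

# The losses and embeddings can then be uploaded with the nomic client
# (e.g. atlas.map_embeddings); exact arguments depend on the client version.
```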
### GPT4All-J Hyperparameters
We varied the learning rate, learning rate schedule, and weight decay following suggestions from the [original GPT-J codebase](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md), but found no real performance difference, qualitatively or quantitatively.
The final model was trained using the following hyperparameters, with a linear warmup followed by a constant learning rate:
| Hyperparameter        | Value |
|-----------------------|-------|
| Per-device batch size | 32    |
| Global batch size     | 256   |
| Learning rate         | 2e-5  |
| Epochs                | 2     |
| Max sequence length   | 1024  |
| Weight decay          | 0     |
| Warmup steps          | 500   |
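
For illustration, a warmup-then-constant schedule like the one described above can be set up as follows; this is a sketch with toy parameters, not the repo's training loop:

```python
# Sketch of the linear-warmup-then-constant schedule; toy parameters stand in
# for the GPT-J weights, and the training-loop body is elided.
import torch
from transformers import get_constant_schedule_with_warmup

params = torch.nn.Linear(8, 8).parameters()  # placeholder for model.parameters()
optimizer = torch.optim.AdamW(params, lr=2e-5, weight_decay=0.0)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=500)

for step in range(1000):
    # ...forward pass and loss.backward() would go here...
    optimizer.step()
    scheduler.step()   # LR ramps linearly to 2e-5 over 500 steps, then stays flat
    optimizer.zero_grad()
```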
The LoRA model was trained using the following hyperparameters, again with a linear warmup followed by a constant learning rate:
| Hyperparameter        | Value |
|-----------------------|-------|
| Per-device batch size | 4     |
| Global batch size     | 32    |
| Learning rate         | 2e-5  |
| Epochs                | 2     |
| Max sequence length   | 1024  |
| Weight decay          | 0     |
| Warmup steps          | 500   |
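
A sketch of how the LoRA adapter could be attached with the `peft` library follows; the rank, alpha, dropout, and target modules are illustrative assumptions, since the log only records the optimizer hyperparameters above.

```python
# Illustrative LoRA setup with peft; r, lora_alpha, lora_dropout, and
# target_modules are assumptions, not values recorded in this log.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # GPT-J attention projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # only the adapter weights remain trainable
```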

New binary files:

- figs/clustering_overfit.png (2.3 MiB)
- figs/overfit-gpt-j.png (356 KiB)