feat: wip training log

parent 1280edd744
commit b170eb9aae
@@ -235,3 +235,46 @@ Taking inspiration from [the Alpaca Repo](https://github.com/tatsu-lab/stanford_
Comparing our LoRA model to the [Alpaca LoRA](https://huggingface.co/tloen/alpaca-lora-7b), our model has lower perplexity. Training for 3 epochs performed best, both on perplexity and on qualitative examples.
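Perplexity here is just the exponential of the mean next-token cross-entropy on held-out text. Below is a minimal sketch of that computation with a Hugging Face causal LM; the checkpoint name and sample text are placeholders, not the exact evaluation setup behind the numbers above.

```python
import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; substitute the (merged) model being evaluated.
model_name = "EleutherAI/gpt-j-6b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

@torch.no_grad()
def perplexity(texts, max_length=1024):
    """Perplexity = exp(mean cross-entropy over all evaluated tokens)."""
    total_nll, total_tokens = 0.0, 0
    for text in texts:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        # Passing the inputs as labels gives the standard next-token loss.
        out = model(**enc, labels=enc["input_ids"])
        n_tokens = enc["input_ids"].numel() - 1  # loss is over shifted tokens
        total_nll += out.loss.item() * n_tokens
        total_tokens += n_tokens
    return math.exp(total_nll / total_tokens)

print(perplexity(["The quick brown fox jumps over the lazy dog."]))
```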
We tried training a full model using the parameters above, but found that the model diverged during the second epoch and that samples generated after training were worse than those from the first epoch.
## GPT-J Training
### Model Training Divergence
We trained multiple [GPT-J models](https://huggingface.co/EleutherAI/gpt-j-6b) with varying success. We found that training the full model led to divergence after epoch 1. ![](figs/overfit-gpt-j.png) We release the checkpoint after epoch 1.
Using Atlas, we extracted the embeddings and calculated the per-sequence loss. We then uploaded [this to Atlas](https://atlas.nomic.ai/map/gpt4all-j-post-epoch-1-embeddings) and noticed that the higher-loss items seemed to cluster. On further inspection, the highest-density clusters seemed to be prompt/response pairs that asked for creative generations such as `Generate a story about ...` ![](figs/clustering_overfit.png)
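Below is a rough sketch of the per-sequence loss computation described above, assuming a Hugging Face causal LM and already-tokenized prompt/response pairs; the Atlas upload itself is only noted in a comment since the exact map settings are not recorded in this log.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def per_sequence_loss(model, input_ids, attention_mask):
    """Mean next-token cross-entropy for each sequence in the batch."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Shift so that token t is predicted from tokens < t.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    shift_mask = attention_mask[:, 1:].float()
    token_loss = F.cross_entropy(
        shift_logits.transpose(1, 2), shift_labels, reduction="none"
    )  # shape: [batch, seq_len - 1]
    return (token_loss * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)

# The per-sequence losses were attached as metadata to the sequence embeddings
# and uploaded with the `nomic` client to produce the Atlas map linked above.
```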
### GPT4All-J Hyperparameters
We varied learning rate, learning rate schedule, and weight decay following suggestions from the [original GPT-J codebase](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md) but found no real performance difference (qualitatively or quantitatively) when varying these parameters.
The final model was trained using the following hyperparameters with a linear warmup followed by a constant learning rate:
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS  | 32    |
| Global BS      | 256   |
| Learning rate  | 2e-5  |
| Epochs         | 2     |
| Max length     | 1024  |
| Weight decay   | 0     |
| Warmup Steps   | 500   |
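For reference, here is a sketch of how the values above map onto Hugging Face `TrainingArguments`. The split of the global batch size (256) into per-device batch size (32) times data-parallel workers and accumulation steps is an assumption (e.g. 8 GPUs with no accumulation), and the model/dataset wiring is omitted.

```python
from transformers import TrainingArguments

# Global BS 256 = per-device 32 x 8 data-parallel workers (assumed split);
# gradient_accumulation_steps would make up the difference on fewer GPUs.
training_args = TrainingArguments(
    output_dir="gpt4all-j-full",               # placeholder output path
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    weight_decay=0.0,
    num_train_epochs=2,
    warmup_steps=500,
    lr_scheduler_type="constant_with_warmup",  # linear warmup, then constant LR
)
# Max length 1024 is enforced at tokenization time rather than here.
```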
The LoRA model was trained using the following hyperparameters with a linear warmup followed by a constant learning rate:
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS  | 4     |
| Global BS      | 32    |
| Learning rate  | 2e-5  |
| Epochs         | 2     |
| Max length     | 1024  |
| Weight decay   | 0     |
| Warmup Steps   | 500   |
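For the LoRA run, a minimal sketch of attaching an adapter with the `peft` library is below. The adapter settings (rank, alpha, dropout, target modules) are illustrative assumptions; only the optimizer-level hyperparameters in the table above come from this log.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")

# Illustrative adapter configuration; not the exact values used for GPT4All-J LoRA.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # GPT-J attention projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```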
New binary files: figs/clustering_overfit.png (2.3 MiB), figs/overfit-gpt-j.png (356 KiB).