Handle training exception for unsupported models

This commit is contained in:
oobabooga 2023-03-29 11:55:34 -03:00 committed by GitHub
parent a6d0373063
commit 58349f44a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@ import json
import sys
import threading
import time
import traceback
from pathlib import Path
import gradio as gr
@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
bias="none",
task_type="CAUSAL_LM"
)
try:
lora_model = get_peft_model(shared.model, config)
except:
yield traceback.format_exc()
return
trainer = transformers.Trainer(
model=lora_model,
train_dataset=train_data,