initial progress tracker in UI

Alex "mcmonkey" Goodwin 2023-03-27 10:25:08 -07:00
parent c07bcd0850
commit 8fc723fc95


@@ -1,4 +1,4 @@
import sys, torch, json
import sys, torch, json, threading, time
from pathlib import Path
import gradio as gr
from datasets import load_dataset
@@ -6,6 +6,9 @@ import transformers
from modules import ui, shared
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
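# Step counters updated by the trainer callback below and read by the progress loop in do_train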
CURRENT_STEPS = 0
MAX_STEPS = 0
def get_json_dataset(path: str):
    def get_set():
        return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob('*.json'))), key=str.lower)
@@ -40,6 +43,12 @@ def create_train_interface():
        output = gr.Markdown(value="(...)")
        startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
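# Trainer callback that mirrors the current and maximum step counts into the module-level globals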
class Callbacks(transformers.TrainerCallback):
    def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
        global CURRENT_STEPS, MAX_STEPS
        CURRENT_STEPS = state.global_step
        MAX_STEPS = state.max_steps
def cleanPath(basePath: str, path: str):
    """Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
    # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -50,8 +59,11 @@ def cleanPath(basePath: str, path: str):
    return f'{Path(basePath).absolute()}/{path}'
def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
    global CURRENT_STEPS, MAX_STEPS
    CURRENT_STEPS = 0
    MAX_STEPS = 0
    yield "Prepping..."
    # Input validation / processing
    # == Input validation / processing ==
    # TODO: --lora-dir PR once pulled will need to be applied here
    loraName = f"loras/{cleanPath(None, loraName)}"
    if dataset is None:
@@ -62,7 +74,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
    actualLR = float(learningRate)
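    # Pad with token id 0, on the left, so real tokens sit at the end of each padded sequence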
    shared.tokenizer.pad_token = 0
    shared.tokenizer.padding_side = "left"
    # Prep the dataset, format, etc
    # == Prep the dataset, format, etc ==
    with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
        formatData: dict[str, str] = json.load(formatFile)
    def tokenize(prompt):
@@ -89,7 +101,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
    else:
        evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
        evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
    # Start prepping the model itself
    # == Start prepping the model itself ==
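    # prepare_model_for_int8_training readies the int8-loaded base model for adapter training (freezing base weights and upcasting norm layers for stability)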
    if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
        print("Getting model ready...")
        prepare_model_for_int8_training(shared.model)
@@ -128,6 +140,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
            ddp_find_unused_parameters=None
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
        callbacks=list([Callbacks()])
    )
    loraModel.config.use_cache = False
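    # Patch state_dict so that saving stores only the LoRA adapter weights rather than the full model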
    old_state_dict = loraModel.state_dict
@@ -136,12 +149,31 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
    ).__get__(loraModel, type(loraModel))
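    # torch.compile requires PyTorch 2.x and is skipped on Windows, where it is not supported at the time of this commit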
    if torch.__version__ >= "2" and sys.platform != "win32":
        loraModel = torch.compile(loraModel)
    # Actually start and run and save at the end
    # == Main run and monitor loop ==
    # TODO: save/load checkpoints to resume from?
    print("Starting training...")
    yield "Running..."
    trainer.train()
    yield "Starting..."
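    # Run the trainer in a background thread so this generator can keep yielding progress updates to the UI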
    def threadedRun():
        trainer.train()
    thread = threading.Thread(target=threadedRun)
    thread.start()
    lastStep = 0
    startTime = time.perf_counter()
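    # Poll every half second while training runs, reporting step count, speed (it/s or s/it), and elapsed time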
    while thread.is_alive():
        time.sleep(0.5)
        if CURRENT_STEPS != lastStep:
            lastStep = CURRENT_STEPS
            timeElapsed = time.perf_counter() - startTime
            if timeElapsed <= 0:
                timerInfo = ""
            else:
                its = CURRENT_STEPS / timeElapsed
                if its > 1:
                    timerInfo = f"`{its:.2f}` it/s"
                else:
                    timerInfo = f"`{1.0/its:.2f}` s/it"
            yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"
    print("Training complete, saving...")
    loraModel.save_pretrained(loraName)
    print("Training complete!")
    yield f"Done! Lora saved to `{loraName}`"
    yield f"Done! LoRA saved to `{loraName}`"