make 'model' variables less ambiguous

Alex "mcmonkey" Goodwin 2023-03-25 12:57:36 -07:00
parent 8da237223e
commit f1ba2196b1


@@ -59,15 +59,13 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
         return "**Missing format choice input, cannot continue.**"
     gradientAccumulationSteps = batchSize // microBatchSize
     actualLR = float(learningRate)
-    model = shared.model
-    tokenizer = shared.tokenizer
-    tokenizer.pad_token = 0
-    tokenizer.padding_side = "left"
+    shared.tokenizer.pad_token = 0
+    shared.tokenizer.padding_side = "left"
     # Prep the dataset, format, etc
     with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
         formatData: dict[str, str] = json.load(formatFile)
     def tokenize(prompt):
-        result = tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length")
+        result = shared.tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length")
         return {
             "input_ids": result["input_ids"][:-1],
             "attention_mask": result["attention_mask"][:-1],
@@ -90,8 +88,8 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
         evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
         evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
     # Start prepping the model itself
-    if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
-        model = prepare_model_for_int8_training(model)
+    if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
+        prepare_model_for_int8_training(shared.model)
     config = LoraConfig(
         r=loraRank,
         lora_alpha=loraAlpha,
@@ -101,9 +99,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
         bias="none",
         task_type="CAUSAL_LM"
     )
-    model = get_peft_model(model, config)
+    loraModel = get_peft_model(shared.model, config)
     trainer = transformers.Trainer(
-        model=model,
+        model=loraModel,
         train_dataset=train_data,
         eval_dataset=evalData,
         args=transformers.TrainingArguments(
@@ -125,16 +123,16 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
             # TODO: Enable multi-device support
             ddp_find_unused_parameters=None,
         ),
-        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
+        data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
     )
-    model.config.use_cache = False
-    old_state_dict = model.state_dict
-    model.state_dict = (
+    loraModel.config.use_cache = False
+    old_state_dict = loraModel.state_dict
+    loraModel.state_dict = (
         lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
-    ).__get__(model, type(model))
+    ).__get__(loraModel, type(loraModel))
     if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
+        loraModel = torch.compile(loraModel)
     # Actually start and run and save at the end
     trainer.train()
-    model.save_pretrained(loraName)
+    loraModel.save_pretrained(loraName)
     return "Done!"