Final tweaks

This commit is contained in:
Eric Wang 2023-03-24 12:43:20 -07:00
parent ee19902c00
commit 972fbfbdee

View File

@@ -41,7 +41,7 @@ assert (
     BASE_MODEL
 ), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
 TRAIN_ON_INPUTS = True
-GROUP_BY_LENGTH = True
+GROUP_BY_LENGTH = True  # faster, but produces an odd training loss curve
 device_map = "auto"
 world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -159,7 +159,7 @@ trainer = transformers.Trainer(
         num_train_epochs=EPOCHS,
         learning_rate=LEARNING_RATE,
         fp16=True,
-        logging_steps=1,
+        logging_steps=10,
         evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
         save_strategy="steps",
         eval_steps=200 if VAL_SET_SIZE > 0 else None,