Unified tokenizer update PR (#146)

* Improve tokenization

The PR changes a few things related to tokenization:

- Sets the padding side to the left, which is required for running batched inference with decoder-only models (see the first sketch after this list).

- Pads each batch to the length of its longest example, rounded up to a multiple of 8 (tensor cores like multiples of 8), instead of always padding to CUTOFF_LEN. This should make training faster, since fewer tokens are fed into the model when they are not needed (~10% faster in my experiments). To implement this correctly I have to append the eos token manually (when the input is not truncated), so I removed `add_eos_token=True` from the tokenizer load call.

- Returns the labels from the tokenize function, since some users seem to prefer it this way. This requires DataCollatorForSeq2Seq so that the labels are padded along with the input ids (see the second sketch after this list). Aside from label padding, the behavior matches DataCollatorForLanguageModeling with mlm=False, so I can revert to that collator if preferred.
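
Below is a minimal sketch (not part of the diff itself) of why left padding matters for batched generation with a decoder-only model; the checkpoint name (taken from the script's assert message) and the prompts are only illustrative:

```python
# Minimal sketch: batched generation with a decoder-only model.
# Checkpoint name and prompts are illustrative, not prescribed by this PR.
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
tokenizer.pad_token_id = 0       # unk, kept distinct from the eos token
tokenizer.padding_side = "left"  # every prompt then ends at the last position

model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf", device_map="auto"
)

prompts = ["Tell me a joke.", "Name three uses for a paperclip."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# With right padding, the shorter prompt would end in pad tokens and generation
# would continue after them; with left padding it starts from real context.
outputs = model.generate(**batch, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```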

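A second sketch shows what the new data collator does; the token ids are toy values and the checkpoint name is again illustrative. DataCollatorForSeq2Seq pads input_ids and attention_mask with the pad token and labels with -100, all to the longest sequence in the batch rounded up to a multiple of 8:

```python
# Minimal sketch (toy token ids): dynamic per-batch padding to a multiple of 8.
from transformers import DataCollatorForSeq2Seq, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"

collator = DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

features = [
    {"input_ids": [1, 306, 4700], "attention_mask": [1, 1, 1], "labels": [1, 306, 4700]},
    {"input_ids": [1, 3439, 17632, 21], "attention_mask": [1, 1, 1, 1], "labels": [1, 3439, 17632, 21]},
]

batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 8]); longest is 4, rounded up to 8
print(batch["labels"][0])        # shorter example's labels left-padded with -100
```
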
* Experimental dynamic batching

* mask out user prompt, again (prompt tokens are labeled -100 so the loss ignores them; see the sketch after this list of commits)

* Add options

* Remove BASE_MODEL again

* Small optimization

* Final tweaks
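
The prompt-masking change above relies on the fact that positions labeled -100 are ignored by the loss. A minimal sketch with toy tensors (not from this PR):

```python
# Minimal sketch: labels set to -100 are ignored by the cross-entropy loss,
# so only the response tokens contribute to training.
import torch
import torch.nn.functional as F

vocab_size = 32000
logits = torch.randn(1, 6, vocab_size)                 # (batch, seq_len, vocab)
labels = torch.tensor([[-100, -100, -100, 42, 7, 2]])  # prompt masked, response kept

loss = F.cross_entropy(
    logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100
)
print(loss)  # averaged over the three unmasked (response) positions only
```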

---------

Co-authored-by: Iker García-Ferrero <i.garciaferrerosanpelayo@gmail.com>
Co-authored-by: Sebastiaan <751205+SharkWipf@users.noreply.github.com>
Committed by Eric J. Wang via GitHub on 2023-03-24 15:46:55 -04:00
Commit: af30df1999 (parent: d3760cd84a)

@@ -23,9 +23,9 @@ from peft import (
 MICRO_BATCH_SIZE = 4  # this could actually be 5 but i like powers of 2
 BATCH_SIZE = 128
 GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
-EPOCHS = 3  # we don't always need 3 tbh
-LEARNING_RATE = 3e-4  # the Karpathy constant
-CUTOFF_LEN = 256  # 256 accounts for about 96% of the data
+EPOCHS = 3
+LEARNING_RATE = 3e-4
+CUTOFF_LEN = 512
 LORA_R = 8
 LORA_ALPHA = 16
 LORA_DROPOUT = 0.05
@@ -40,6 +40,8 @@ BASE_MODEL = None
 assert (
     BASE_MODEL
 ), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
+TRAIN_ON_INPUTS = True
+GROUP_BY_LENGTH = True  # faster, but produces an odd training loss curve
 device_map = "auto"
 world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -53,7 +55,11 @@ model = LlamaForCausalLM.from_pretrained(
     load_in_8bit=True,
     device_map=device_map,
 )
-tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL, add_eos_token=True)
+tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
+tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
+tokenizer.padding_side = "left"  # Allow batched inference
 model = prepare_model_for_int8_training(model)
@@ -66,7 +72,7 @@ config = LoraConfig(
     task_type="CAUSAL_LM",
 )
 model = get_peft_model(model, config)
-tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
 data = load_dataset("json", data_files=DATA_PATH)
@@ -93,24 +99,43 @@ def generate_prompt(data_point):
 {data_point["output"]}"""
-def tokenize(prompt):
+def tokenize(prompt, add_eos_token=True):
     # there's probably a way to do this with the tokenizer settings
     # but again, gotta move fast
     result = tokenizer(
         prompt,
         truncation=True,
-        max_length=CUTOFF_LEN + 1,
-        padding="max_length",
+        max_length=CUTOFF_LEN,
+        padding=False,
+        return_tensors=None,
     )
-    return {
-        "input_ids": result["input_ids"][:-1],
-        "attention_mask": result["attention_mask"][:-1],
-    }
+    if (
+        result["input_ids"][-1] != tokenizer.eos_token_id
+        and len(result["input_ids"]) < CUTOFF_LEN
+        and add_eos_token
+    ):
+        result["input_ids"].append(tokenizer.eos_token_id)
+        result["attention_mask"].append(1)
+    result["labels"] = result["input_ids"].copy()
+    return result
 def generate_and_tokenize_prompt(data_point):
-    prompt = generate_prompt(data_point)
-    return tokenize(prompt)
+    full_prompt = generate_prompt(data_point)
+    tokenized_full_prompt = tokenize(full_prompt)
+    if not TRAIN_ON_INPUTS:
+        user_prompt = generate_prompt({**data_point, "output": ""})
+        tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
+        user_prompt_len = len(tokenized_user_prompt["input_ids"])
+        tokenized_full_prompt["labels"] = [
+            -100
+        ] * user_prompt_len + tokenized_full_prompt["labels"][
+            user_prompt_len:
+        ]  # could be sped up, probably
+    return tokenized_full_prompt
 if VAL_SET_SIZE > 0:
if VAL_SET_SIZE > 0:
@@ -134,7 +159,7 @@ trainer = transformers.Trainer(
         num_train_epochs=EPOCHS,
         learning_rate=LEARNING_RATE,
         fp16=True,
-        logging_steps=20,
+        logging_steps=10,
         evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
         save_strategy="steps",
         eval_steps=200 if VAL_SET_SIZE > 0 else None,
@@ -143,8 +168,11 @@ trainer = transformers.Trainer(
         save_total_limit=3,
         load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
         ddp_find_unused_parameters=False if ddp else None,
+        group_by_length=GROUP_BY_LENGTH,
     ),
+    data_collator=transformers.DataCollatorForSeq2Seq(
+        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+    ),
-    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
 )
 model.config.use_cache = False