fix: add eos

Zach Nussbaum 2023-03-26 17:45:31 +00:00
parent 2daecd6066
commit eac7734cbf

data.py

@@ -15,21 +15,24 @@ def tokenize_inputs(config, tokenizer, examples):
     out = {"labels": [], "attention_mask": []}
     for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
         # HACK to get 512 to work for now
-        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length //2, return_tensors="pt")["input_ids"].squeeze()
+        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
         input_len = len(input_tokens)
         # plus one since we remove bos from response
-        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
+        # but we subtract one since we want to add eos token
+        remaining_tokens = max_length - input_len - len(newline_tokens)
         # remove bos
         target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
         input_ids[i, :input_len] = input_tokens
         # add newline between prompt and response
         newline_plus_inputs = input_len + len(newline_tokens)
         input_ids[i, input_len: newline_plus_inputs] = newline_tokens
         # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
+        # add eos token, enforce stopping
+        input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
         labels = input_ids[i].clone()
         labels[: newline_plus_inputs] = -100
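
Below is a minimal, runnable sketch of the packing layout this hunk produces, using made-up token ids in place of real tokenizer output. The values of max_length, newline_tokens, and the pre-allocated input_ids buffer are set earlier in tokenize_inputs and are not visible in this hunk, so the stand-ins here (max_length=16, a single newline token, a zero-filled 1-D buffer) are assumptions for illustration only.

import torch

# Stand-ins for values defined earlier in tokenize_inputs (not in this hunk).
max_length = 16                              # the repo uses 512 (see the "HACK" comment)
eos_token_id = 2                             # hypothetical eos id
newline_tokens = torch.tensor([13])          # stand-in for the tokenized "\n"

input_tokens = torch.tensor([5, 6, 7])       # tokenized prompt (<= max_length // 2)
input_len = len(input_tokens)

# Response budget: the old "+ 1" (the slot freed by stripping the response's bos)
# is gone, because that slot is now reserved for the trailing eos token.
remaining_tokens = max_length - input_len - len(newline_tokens)
response_with_bos = torch.tensor([1, 20, 21, 22, 23])
target_tokens = response_with_bos[:remaining_tokens][1:]   # truncate, then drop bos

# Pack prompt + newline + response + eos into one fixed-length row.
input_ids = torch.zeros(max_length, dtype=torch.long)
input_ids[:input_len] = input_tokens
newline_plus_inputs = input_len + len(newline_tokens)
input_ids[input_len:newline_plus_inputs] = newline_tokens
input_ids[newline_plus_inputs:newline_plus_inputs + len(target_tokens)] = target_tokens
# The line this commit adds: terminate the packed example with eos so the
# model learns to stop generating after the response.
input_ids[newline_plus_inputs + len(target_tokens)] = eos_token_id

# Loss is only computed on the response + eos; the prompt and newline are masked.
labels = input_ids.clone()
labels[:newline_plus_inputs] = -100
print(input_ids.tolist())  # [5, 6, 7, 13, 20, 21, 22, 23, 2, 0, 0, 0, 0, 0, 0, 0]
print(labels.tolist())     # [-100, -100, -100, -100, 20, 21, 22, 23, 2, 0, 0, 0, 0, 0, 0, 0]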