fix: prompt len for larger prompts

Zach 2023-04-04 22:01:55 +00:00
parent df2d5f7e46
commit 65ec606f21


@@ -9,7 +9,6 @@ from transformers import DefaultDataCollator
def tokenize_inputs(config, tokenizer, examples):
    max_length = config["max_length"]
    input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
    # ignore bos
    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
    if newline_tokens[0] == tokenizer.bos_token_id:
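
Some tokenizers prepend a BOS token to every encoding, so tokenizer("\n") can come back as [bos_id, newline_id] rather than just the newline id; the check above guards against that before newline_tokens is used further down in the function. A minimal sketch of the guard, assuming a generic Hugging Face tokenizer (the "gpt2" checkpoint is an illustrative choice, not the model used by this repo, and the branch is only taken by BOS-prepending tokenizers such as LLaMA-style ones):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint only
    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
    if newline_tokens[0] == tokenizer.bos_token_id:
        # drop the leading BOS so only the literal newline ids remain
        newline_tokens = newline_tokens[1:]
    print(newline_tokens)
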
@@ -29,6 +28,7 @@ def tokenize_inputs(config, tokenizer, examples):
        # we need to include some labels
        if prompt_len >= max_length - 1:
            prompt = prompt[:len(prompt) // 2]
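            # re-tokenize the shortened prompt so prompt_len matches the truncated text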
            prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])

        input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
                                 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
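
The added prompt_len line is the fix named in the commit title: when the tokenized prompt already reaches max_length - 1, the prompt string is cut in half, and without re-tokenizing it prompt_len would still describe the untruncated prompt, so any later use of that length would be off. A hedged sketch of the difference, assuming only what this hunk shows (the "gpt2" checkpoint, the max_length value, and the sample prompt are illustrative, not taken from the repo):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint only
    max_length = 32
    prompt = "word " * 40  # tokenizes to more than max_length tokens

    prompt_len = len(tokenizer(prompt, truncation=True, max_length=max_length,
                               return_tensors="pt")["input_ids"][0])
    if prompt_len >= max_length - 1:
        prompt = prompt[:len(prompt) // 2]
        # without this re-measurement, prompt_len keeps its stale pre-truncation value
        prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])

    print(prompt_len)  # now reflects the halved prompt rather than the original one
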