mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
fix: prompt len for larger
This commit is contained in:
parent
df2d5f7e46
commit
65ec606f21
2
data.py
2
data.py
@@ -9,7 +9,6 @@ from transformers import DefaultDataCollator
|
|||||||
|
|
||||||
def tokenize_inputs(config, tokenizer, examples):
|
def tokenize_inputs(config, tokenizer, examples):
|
||||||
max_length = config["max_length"]
|
max_length = config["max_length"]
|
||||||
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
|
|
||||||
# ignore bos
|
# ignore bos
|
||||||
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
|
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
|
||||||
if newline_tokens[0] == tokenizer.bos_token_id:
|
if newline_tokens[0] == tokenizer.bos_token_id:
|
||||||
@@ -29,6 +28,7 @@ def tokenize_inputs(config, tokenizer, examples):
|
|||||||
# we need to include some labels
|
# we need to include some labels
|
||||||
if prompt_len >= max_length - 1:
|
if prompt_len >= max_length - 1:
|
||||||
prompt = prompt[:len(prompt) // 2]
|
prompt = prompt[:len(prompt) // 2]
|
||||||
|
prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
|
||||||
|
|
||||||
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
|
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
|
||||||
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
||||||
|
Loading…
Reference in New Issue
Block a user