diff --git a/data.py b/data.py
index ef84cc2d..0cff50c7 100644
--- a/data.py
+++ b/data.py
@@ -1,6 +1,6 @@
 import glob
 import torch
-from datasets import load_dataset
+from datasets import load_dataset, concatenate_datasets
 import os
 from torch.utils.data import DataLoader
 from transformers import DefaultDataCollator
@@ -20,7 +20,7 @@ def tokenize_inputs(config, tokenizer, examples):
 
         # plus one since we remove bos from response
         # but we subtract one since we want to add eos token
-        remaining_tokens = max_length - input_len - len(newline_tokens)
+        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
         # remove bos
         target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
 
@@ -31,8 +31,10 @@ def tokenize_inputs(config, tokenizer, examples):
         # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
 
-        # add eos token, enforce stopping
-        input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
+        # add eos token, enforce stopping if we don't truncate
+        # we don't want long code to stop generating if truncated during training
+        if newline_plus_inputs + len(target_tokens) < max_length:
+            input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
 
         labels = input_ids[i].clone()
         labels[: newline_plus_inputs] = -100
@@ -51,7 +53,6 @@ def tokenize_inputs(config, tokenizer, examples):
 
     return out
 
-
 def load_data(config, tokenizer):
     dataset_path = config["dataset_path"]
 
@@ -62,16 +63,22 @@ def load_data(config, tokenizer):
         else:
             files = [dataset_path]
 
+        print(f"Reading files {files}")
+
         dataset = load_dataset("json", data_files=files, split="train")
 
     else:
         dataset = load_dataset(dataset_path)
 
-
+    uuids = dataset.filter(lambda x: x["source"] == "nomic")
+    dataset = dataset.filter(lambda x: x["source"] != "nomic")
     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
 
    train_dataset, val_dataset = dataset["train"], dataset["test"]
+    train_dataset = concatenate_datasets([train_dataset, uuids])
+    train_dataset = train_dataset.shuffle(seed=config["seed"])
+
     if config["streaming"] is False:
         kwargs = {"num_proc": config["num_proc"]}
     else:
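For quick reference, a minimal standalone sketch of the guarded EOS placement that the tokenize_inputs hunk introduces. All concrete values here (max_length, token ids, token counts) are made up for illustration; only the guard itself mirrors the diff.

    import torch

    # Hypothetical sizes/ids, not taken from the training config.
    max_length = 8
    pad_token_id, eos_token_id = 0, 2

    input_ids = torch.full((1, max_length), pad_token_id)
    newline_plus_inputs = 3                        # prompt + newline tokens
    target_tokens = torch.tensor([5, 6, 7, 8, 9])  # response truncated to fill the window

    input_ids[0, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens

    # Only place EOS when the response did not consume the whole window;
    # a truncated example is deliberately left without a stop token.
    if newline_plus_inputs + len(target_tokens) < max_length:
        input_ids[0, newline_plus_inputs + len(target_tokens)] = eos_token_id

    print(input_ids)  # tensor([[0, 0, 0, 5, 6, 7, 8, 9]]) -- no EOS appended

With a shorter response (e.g. target_tokens = torch.tensor([5, 6])), the guard passes and eos_token_id lands directly after the response tokens.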