fix: eos conditional, watermark

commit bb28929305
parent 10db136a88
Author: Zach Nussbaum
Date:   2023-03-27 16:29:43 +00:00

data.py | 17 ++++++++++++-----

@@ -1,6 +1,6 @@
 import glob
 import torch
-from datasets import load_dataset
+from datasets import load_dataset, concatenate_datasets
 import os
 from torch.utils.data import DataLoader
 from transformers import DefaultDataCollator
@@ -20,7 +20,7 @@ def tokenize_inputs(config, tokenizer, examples):
         # plus one since we remove bos from response
         # but we subtract one since we want to add eos token
-        remaining_tokens = max_length - input_len - len(newline_tokens)
+        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
         # remove bos
         target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
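
One hedged reading of the +1 (my gloss, not stated in the commit): the response is tokenized with max_length=remaining_tokens and its BOS is then dropped via [1:], so one slot of the budget was always spent on a token that never survives; adding the slot back lets the response use the full window, while the conditional in the next hunk appends EOS only when room remains. Illustrative arithmetic with made-up sizes:

    # Hypothetical sizes, chosen only to illustrate the budget change.
    max_length, input_len, newline_len = 1024, 100, 2

    budget_old = max_length - input_len - newline_len      # 922, incl. BOS
    budget_new = max_length - input_len - newline_len + 1  # 923, incl. BOS

    # After [1:] strips BOS from the tokenized response:
    usable_old = budget_old - 1   # 921 response tokens survive
    usable_new = budget_new - 1   # 922, the full intended response budget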
@@ -31,7 +31,9 @@ def tokenize_inputs(config, tokenizer, examples):
         # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
 
-        # add eos token, enforce stopping
-        input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
+        # add eos token, enforce stopping if we don't truncate
+        # we don't want long code to stop generating if truncated during training
+        if newline_plus_inputs + len(target_tokens) < max_length:
+            input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
 
         labels = input_ids[i].clone()
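
Taken alone, the new conditional only writes eos_token_id when the response actually fit inside max_length; a truncated long example gets no EOS, so the model is not taught to stop mid-output. A self-contained toy sketch of the same pattern (tensor sizes and token ids invented):

    import torch

    max_length = 8
    eos_token_id = 2                         # stand-in for tokenizer.eos_token_id
    input_ids = torch.zeros(1, max_length, dtype=torch.long)

    newline_plus_inputs = 3                  # prompt + newline length
    target_tokens = torch.tensor([5, 6, 7])  # tokenized response, BOS removed

    input_ids[0, newline_plus_inputs:newline_plus_inputs + len(target_tokens)] = target_tokens

    # Append EOS only if the response was not truncated; otherwise leave the
    # sequence open so long outputs don't learn to stop early.
    if newline_plus_inputs + len(target_tokens) < max_length:
        input_ids[0, newline_plus_inputs + len(target_tokens)] = eos_token_id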
@@ -51,7 +53,6 @@ def tokenize_inputs(config, tokenizer, examples):
     return out
-
 
 def load_data(config, tokenizer):
     dataset_path = config["dataset_path"]
@@ -62,16 +63,22 @@ def load_data(config, tokenizer):
         else:
             files = [dataset_path]
 
+        print(f"Reading files {files}")
         dataset = load_dataset("json", data_files=files, split="train")
     else:
         dataset = load_dataset(dataset_path)
 
+    uuids = dataset.filter(lambda x: x["source"] == "nomic")
+    dataset = dataset.filter(lambda x: x["source"] != "nomic")
+
     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
     train_dataset, val_dataset = dataset["train"], dataset["test"]
 
+    train_dataset = concatenate_datasets([train_dataset, uuids])
+    train_dataset = train_dataset.shuffle(seed=config["seed"])
 
     if config["streaming"] is False:
         kwargs = {"num_proc": config["num_proc"]}
     else:
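
Read end to end, the load_data changes implement the watermark from the commit title: rows whose source is "nomic" are held out of the train/test split so they can never land in validation, then concatenated back into the training set and shuffled in. A condensed sketch of that flow using the same datasets calls (toy rows and seed are invented; test_size is raised from .05 only so the four-row toy split is non-empty):

    from datasets import Dataset, concatenate_datasets

    # Invented toy data standing in for the real training files.
    dataset = Dataset.from_dict({
        "prompt": ["p1", "p2", "p3", "p4", "p5"],
        "source": ["web", "nomic", "web", "web", "web"],
    })

    # Hold the watermark rows aside so the split cannot put them in "test".
    uuids = dataset.filter(lambda x: x["source"] == "nomic")
    dataset = dataset.filter(lambda x: x["source"] != "nomic")

    split = dataset.train_test_split(test_size=.25, seed=42)
    train_dataset, val_dataset = split["train"], split["test"]

    # Re-insert the watermark rows into train only, then shuffle them in.
    train_dataset = concatenate_datasets([train_dataset, uuids])
    train_dataset = train_dataset.shuffle(seed=42)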