diff --git a/data.py b/data.py
index 6375584d..8a0dd83f 100644
--- a/data.py
+++ b/data.py
@@ -142,7 +142,11 @@ def load_data_for_inference(config, tokenizer):
     train_dataset, val_dataset = dataset["train"], dataset["test"]
 
     train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
+    # truncate to the largest multiple of batch_size
+    # gather (at least the way it's used here) is awkward with uneven batches, since it pads them by duplicating data
+    train_dataset = train_dataset.select(range((len(train_dataset) // config["batch_size"]) * config["batch_size"]))
     val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
+    val_dataset = val_dataset.select(range((len(val_dataset) // config["batch_size"]) * config["batch_size"]))
 
     if config["streaming"] is False:
         kwargs = {"num_proc": config["num_proc"]}
diff --git a/inference.py b/inference.py
index 3facd91c..1fb620ab 100644
--- a/inference.py
+++ b/inference.py
@@ -46,20 +46,22 @@ def inference(config):
     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()
 
-    train_sampler = ShardSampler(train_dataset, config["batch_size"], num_processes=num_processes, process_index=local_rank)
+    train_sampler = ShardSampler(train_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
     train_dataloader = DataLoader(
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
-        sampler=train_sampler
+        sampler=train_sampler,
+        drop_last=True
     )
 
-    val_sampler = ShardSampler(val_dataset, config["batch_size"], num_processes=num_processes, process_index=local_rank)
+    val_sampler = ShardSampler(val_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
-        sampler=val_sampler
+        sampler=val_sampler,
+        drop_last=True
     )
 
@@ -113,7 +115,6 @@ def inference(config):
    df_train = Dataset.from_dict(gathered_train)
    df_train = df_train.sort("index")
 
-   train_dataset = train_dataset.add_column("embeddings", df_train["embeddings"])
    train_dataset = train_dataset.add_column("loss", df_train["loss"])
    train_dataset = train_dataset.add_column("is_train", [True] * len(train_dataset))
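
Note on why the truncation and `drop_last=True` are needed: assuming `ShardSampler` here is `transformers.trainer_pt_utils.ShardSampler` (the import isn't shown in the diff), it pads its index list by looping back to the start of the dataset whenever the length isn't a multiple of `batch_size * num_processes`, so the padding comes back as duplicate rows after the gather. A minimal sketch of that behavior with toy sizes:

```python
# Toy repro of the duplication the comments in data.py refer to
# (assumes ShardSampler is the one from transformers.trainer_pt_utils).
from transformers.trainer_pt_utils import ShardSampler

dataset = list(range(10))  # 10 samples; batch_size * num_processes = 8

# Without drop_last, the sampler loops back to index 0 to fill the last batch.
rank0 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=0)
rank1 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=1)
print(list(rank0))  # [0, 1, 2, 3, 8, 9, 0, 1] -> 0 and 1 seen twice
print(list(rank1))  # [4, 5, 6, 7, 2, 3, 4, 5] -> 2..5 duplicated across ranks

# With drop_last=True the trailing partial batch is dropped instead, which is
# why data.py also truncates each split to a multiple of batch_size.
rank0_drop = ShardSampler(dataset, batch_size=4, drop_last=True,
                          num_processes=2, process_index=0)
print(list(rank0_drop))  # [0, 1, 2, 3]
```

Truncating the datasets up front keeps the `index` column aligned with what actually comes back from the gather, so the later `sort("index")` / `add_column` steps see one row per sample.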