fix: pyarrow filter

This commit is contained in:
Zach Nussbaum 2023-04-07 19:04:19 +00:00
parent 7a9f6d1cdc
commit 4b51e6ef37

View File

@@ -12,6 +12,8 @@ from transformers.trainer_pt_utils import nested_numpify
from transformers import DefaultDataCollator from transformers import DefaultDataCollator
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
import numpy as np import numpy as np
import pyarrow as pa
from pyarrow import compute as pc
def calc_cross_entropy_no_reduction(lm_logits, labels): def calc_cross_entropy_no_reduction(lm_logits, labels):
@@ -116,7 +118,13 @@ def inference(config):
df_train = df_train.sort("index") df_train = df_train.sort("index")
curr_idx = df_train["index"] curr_idx = df_train["index"]
filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx) # compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = train_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"]) filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"]) filtered_train = filtered_train.add_column("loss", df_train["loss"])
@@ -167,7 +175,13 @@ def inference(config):
df_val = df_val.sort("index") df_val = df_val.sort("index")
curr_idx = df_val["index"] curr_idx = df_val["index"]
filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx) # compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = val_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"]) filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
filtered_val = filtered_val.add_column("loss", df_val["loss"]) filtered_val = filtered_val.add_column("loss", df_val["loss"])