mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
fix: pyarrow filter
This commit is contained in:
parent
7a9f6d1cdc
commit
4b51e6ef37
18
inference.py
18
inference.py
@ -12,6 +12,8 @@ from transformers.trainer_pt_utils import nested_numpify
|
||||
from transformers import DefaultDataCollator
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from pyarrow import compute as pc
|
||||
|
||||
|
||||
def calc_cross_entropy_no_reduction(lm_logits, labels):
|
||||
@ -116,7 +118,13 @@ def inference(config):
|
||||
df_train = df_train.sort("index")
|
||||
curr_idx = df_train["index"]
|
||||
|
||||
filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx)
|
||||
# compute mask in pyarrow since it's super fast
|
||||
# ty @bmschmidt for showing me this!
|
||||
table = train_dataset.data
|
||||
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
|
||||
filtered_table = table.filter(mask)
|
||||
# convert from pyarrow to Dataset
|
||||
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
|
||||
|
||||
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
|
||||
filtered_train = filtered_train.add_column("loss", df_train["loss"])
|
||||
@ -167,7 +175,13 @@ def inference(config):
|
||||
df_val = df_val.sort("index")
|
||||
curr_idx = df_val["index"]
|
||||
|
||||
filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx)
|
||||
# compute mask in pyarrow since it's super fast
|
||||
# ty @bmschmidt for showing me this!
|
||||
table = val_dataset.data
|
||||
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
|
||||
filtered_table = table.filter(mask)
|
||||
# convert from pyarrow to Dataset
|
||||
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
|
||||
|
||||
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
|
||||
filtered_val = filtered_val.add_column("loss", df_val["loss"])
|
||||
|
Loading…
Reference in New Issue
Block a user