mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
fix: pyarrow filter
This commit is contained in:
parent
7a9f6d1cdc
commit
4b51e6ef37
18
inference.py
18
inference.py
@ -12,6 +12,8 @@ from transformers.trainer_pt_utils import nested_numpify
|
|||||||
from transformers import DefaultDataCollator
|
from transformers import DefaultDataCollator
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
from pyarrow import compute as pc
|
||||||
|
|
||||||
|
|
||||||
def calc_cross_entropy_no_reduction(lm_logits, labels):
|
def calc_cross_entropy_no_reduction(lm_logits, labels):
|
||||||
@ -116,7 +118,13 @@ def inference(config):
|
|||||||
df_train = df_train.sort("index")
|
df_train = df_train.sort("index")
|
||||||
curr_idx = df_train["index"]
|
curr_idx = df_train["index"]
|
||||||
|
|
||||||
filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx)
|
# compute the membership mask in pyarrow; this is much faster than filtering with a per-example Python lambda
|
||||||
|
# thanks to @bmschmidt for showing me this!
|
||||||
|
table = train_dataset.data
|
||||||
|
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
|
||||||
|
filtered_table = table.filter(mask)
|
||||||
|
# convert from pyarrow to Dataset
|
||||||
|
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
|
||||||
|
|
||||||
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
|
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
|
||||||
filtered_train = filtered_train.add_column("loss", df_train["loss"])
|
filtered_train = filtered_train.add_column("loss", df_train["loss"])
|
||||||
@ -167,7 +175,13 @@ def inference(config):
|
|||||||
df_val = df_val.sort("index")
|
df_val = df_val.sort("index")
|
||||||
curr_idx = df_val["index"]
|
curr_idx = df_val["index"]
|
||||||
|
|
||||||
filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx)
|
# compute the membership mask in pyarrow; this is much faster than filtering with a per-example Python lambda
|
||||||
|
# thanks to @bmschmidt for showing me this!
|
||||||
|
table = val_dataset.data
|
||||||
|
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
|
||||||
|
filtered_table = table.filter(mask)
|
||||||
|
# convert from pyarrow to Dataset
|
||||||
|
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
|
||||||
|
|
||||||
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
|
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
|
||||||
filtered_val = filtered_val.add_column("loss", df_val["loss"])
|
filtered_val = filtered_val.add_column("loss", df_val["loss"])
|
||||||
|
Loading…
Reference in New Issue
Block a user