Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2024-10-01 01:06:10 -04:00
fix: drop uneven batch size
commit 0bd6acb4dd
parent 985da51fbc
data.py (+4 -0)

@@ -142,7 +142,11 @@ def load_data_for_inference(config, tokenizer):
     train_dataset, val_dataset = dataset["train"], dataset["test"]
 
     train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
+    # select first N batches that are divisible by batch_size
+    # gather is a bit annoying (or the way I'm using it) to get uneven batches as it duplicates data
+    train_dataset = train_dataset.select(range((len(train_dataset) // config["batch_size"]) * config["batch_size"]))
     val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
+    val_dataset = val_dataset.select(range((len(val_dataset) // config["batch_size"]) * config["batch_size"]))
 
     if config["streaming"] is False:
         kwargs = {"num_proc": config["num_proc"]}
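The arithmetic behind those select() calls is plain floor division: (len(dataset) // batch_size) * batch_size is the largest multiple of batch_size that fits, and everything past it is dropped. A minimal sketch of the same truncation in isolation (the helper name and toy numbers are illustrative, not from the repo):

    # Hedged sketch: floor the dataset length to a multiple of batch_size,
    # mirroring the select(range(...)) calls in the diff above.
    def full_batch_indices(dataset_len: int, batch_size: int) -> range:
        usable = (dataset_len // batch_size) * batch_size
        return range(usable)

    print(list(full_batch_indices(10, 4)))  # [0, 1, ..., 7] -- 2 trailing rows dropped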
inference.py (+6 -5)

@@ -46,20 +46,22 @@ def inference(config):
     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()
 
-    train_sampler = ShardSampler(train_dataset, config["batch_size"], num_processes=num_processes, process_index=local_rank)
+    train_sampler = ShardSampler(train_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
     train_dataloader = DataLoader(
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
-        sampler=train_sampler
+        sampler=train_sampler,
+        drop_last=True
     )
 
-    val_sampler = ShardSampler(val_dataset, config["batch_size"], num_processes=num_processes, process_index=local_rank)
+    val_sampler = ShardSampler(val_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
-        sampler=val_sampler
+        sampler=val_sampler,
+        drop_last=True
     )
 
 
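For reference, drop_last=True on a PyTorch DataLoader discards the final short batch so every step sees a full batch_size. A quick standalone check (a toy list of integers stands in for the real tokenized dataset):

    # Standalone check of DataLoader's drop_last flag with a toy dataset.
    from torch.utils.data import DataLoader

    data = list(range(10))
    print([len(b) for b in DataLoader(data, batch_size=4)])                  # [4, 4, 2]
    print([len(b) for b in DataLoader(data, batch_size=4, drop_last=True)])  # [4, 4]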
@@ -113,7 +115,6 @@ def inference(config):
 
     df_train = Dataset.from_dict(gathered_train)
     df_train = df_train.sort("index")
-
     train_dataset = train_dataset.add_column("embeddings", df_train["embeddings"])
     train_dataset = train_dataset.add_column("loss", df_train["loss"])
     train_dataset = train_dataset.add_column("is_train", [True] * len(train_dataset))
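The data.py comment about gather duplicating data refers to how shard samplers equalize ranks: when len(dataset) is not a multiple of batch_size * num_processes, the index list is padded by wrapping around to the start, so a later gather across ranks sees the wrapped examples twice. A rough, framework-free sketch of that behavior, modeled on the idea behind transformers' ShardSampler but simplified (not the library's exact code):

    import math

    # Rough sketch: pad indices by wrapping around so every rank gets the
    # same number of full batches -- the wrapped indices are the duplicates
    # that drop_last=True (or pre-truncating the dataset) avoids.
    def shard_indices(dataset_len, batch_size, num_processes, drop_last=False):
        total_batch = batch_size * num_processes
        num_batches = (dataset_len // total_batch if drop_last
                       else math.ceil(dataset_len / total_batch))
        total = num_batches * total_batch
        indices = list(range(dataset_len))
        while len(indices) < total:          # wrap-around padding => duplicates
            indices += indices[: total - len(indices)]
        indices = indices[:total]
        # rank r takes its batch_size-wide slice out of every global batch
        return [
            [i for start in range(r * batch_size, total, total_batch)
             for i in indices[start:start + batch_size]]
            for r in range(num_processes)
        ]

    print(shard_indices(10, 4, 2))
    # [[0, 1, 2, 3, 8, 9, 0, 1], [4, 5, 6, 7, 2, 3, 4, 5]] -- 0..5 appear twice
    print(shard_indices(10, 4, 2, drop_last=True))
    # [[0, 1, 2, 3], [4, 5, 6, 7]] -- full batches only, no duplicates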