mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
feat: train and clean data
This commit is contained in:
parent
2568d94e50
commit
723a50bdf1
71
clean.py
Normal file
71
clean.py
Normal file
@ -0,0 +1,71 @@
|
||||
import numpy as np
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
import jsonlines
|
||||
import pandas as pd
|
||||
|
||||
|
||||
prompt_generation_dir = "prompts-reponses"
|
||||
for file in glob.glob(os.path.join(prompt_generation_dir, "*.jsonl")):
|
||||
data = []
|
||||
print(file)
|
||||
with open(file) as f:
|
||||
for line in f:
|
||||
try:
|
||||
contents = json.loads(line)
|
||||
data.append(contents)
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
processed = []
|
||||
|
||||
for item in data:
|
||||
if 'source' not in item:
|
||||
item['source'] = 'unspecified'
|
||||
if 'model_settings' in item:
|
||||
item.pop('model_settings', None)
|
||||
|
||||
for key in list(item.keys()):
|
||||
if key not in ['source', 'prompt', 'response']:
|
||||
#print(item[key])
|
||||
item.pop(key, None)
|
||||
|
||||
if isinstance(item['prompt'], dict):
|
||||
if "value" in item["prompt"]:
|
||||
item["prompt"] = item["prompt"]["value"]
|
||||
elif "description" in item["prompt"]:
|
||||
item["prompt"] = item["prompt"]["description"]
|
||||
else:
|
||||
continue
|
||||
|
||||
elif not isinstance(item['prompt'], str):
|
||||
continue
|
||||
|
||||
if isinstance(item['response'], dict):
|
||||
if "value" in item["response"]:
|
||||
item["response"] = item["response"]["value"]
|
||||
elif "description" in item["response"]:
|
||||
item["response"] = item["response"]["description"]
|
||||
else:
|
||||
continue
|
||||
elif not isinstance(item['response'], str):
|
||||
continue
|
||||
|
||||
if item:
|
||||
processed.append(item)
|
||||
|
||||
df = pd.DataFrame(processed)
|
||||
prev_len = len(df)
|
||||
|
||||
# drop empty or null string
|
||||
df = df.dropna(subset=['prompt', 'response'])
|
||||
df = df[df['prompt'] != '']
|
||||
df = df[df['response'] != '']
|
||||
curr_len = len(df)
|
||||
|
||||
print(f"Removed {prev_len - curr_len} rows")
|
||||
|
||||
clean_name = file.split(".jsonl")[0] + "_clean.jsonl"
|
||||
print(f"writing to {clean_name}")
|
||||
df.to_json(clean_name, orient="records", lines=True)
|
48
configs/deepspeed/ds_config.json
Normal file
48
configs/deepspeed/ds_config.json
Normal file
@ -0,0 +1,48 @@
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"min_loss_scale": 1,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"initial_scale_power": 32
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"gradient_clipping": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_param": {
|
||||
"device": "none"
|
||||
},
|
||||
"offload_optimizer": {
|
||||
"device": "none"
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.999
|
||||
],
|
||||
"eps": 1e-08
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
}
|
||||
}
|
||||
}
|
28
configs/train/finetune.yaml
Normal file
28
configs/train/finetune.yaml
Normal file
@ -0,0 +1,28 @@
|
||||
# model/tokenizer
|
||||
model_name: "zpn/llama-7b"
|
||||
tokenizer_name: "zpn/llama-7b"
|
||||
gradient_checkpointing: true
|
||||
|
||||
# dataset
|
||||
streaming: false
|
||||
num_proc: 64
|
||||
dataset_path: "data.jsonl"
|
||||
max_length: 512
|
||||
batch_size: 32
|
||||
|
||||
# train dynamics
|
||||
lr: 5.0e-5
|
||||
eval_every: 2000
|
||||
eval_steps: 100
|
||||
save_every: 2000
|
||||
output_dir: "ckpts/llama-7b"
|
||||
checkpoint: null
|
||||
lora: false
|
||||
warmup_steps: 100
|
||||
|
||||
# logging
|
||||
wandb: false
|
||||
wandb_entity: zanussbaum
|
||||
wandb_project: llama
|
||||
seed: 42
|
||||
|
29
configs/train/finetune_lora.yaml
Normal file
29
configs/train/finetune_lora.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
# model/tokenizer
|
||||
model_name: "zpn/llama-7b"
|
||||
tokenizer_name: "zpn/llama-7b"
|
||||
gradient_checkpointing: false
|
||||
save_name: "zpn/vicuna-lora"
|
||||
|
||||
# dataset
|
||||
streaming: false
|
||||
num_proc: 64
|
||||
dataset_path: "data"
|
||||
max_length: 512
|
||||
batch_size: 8
|
||||
|
||||
# train dynamics
|
||||
lr: 5.0e-5
|
||||
eval_every: 2000
|
||||
eval_steps: 100
|
||||
save_every: 2000
|
||||
output_dir: "ckpts/llama-7b"
|
||||
checkpoint: null
|
||||
lora: true
|
||||
warmup_steps: 100
|
||||
|
||||
# logging
|
||||
wandb: false
|
||||
wandb_entity: zanussbaum
|
||||
wandb_project: llama
|
||||
seed: 42
|
||||
|
108
data.py
Normal file
108
data.py
Normal file
@ -0,0 +1,108 @@
|
||||
import glob
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
import os
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import DefaultDataCollator
|
||||
|
||||
|
||||
|
||||
def tokenize_inputs(config, tokenizer, examples):
|
||||
max_length = config["max_length"]
|
||||
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
|
||||
# ignore bos
|
||||
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
|
||||
|
||||
out = {"labels": [], "attention_mask": []}
|
||||
for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
|
||||
# HACK to get 512 to work for now
|
||||
input_tokens = tokenizer(prompt, truncation=True, max_length=max_length //2, return_tensors="pt")["input_ids"].squeeze()
|
||||
input_len = len(input_tokens)
|
||||
|
||||
# plus one since we remove bos from response
|
||||
remaining_tokens = max_length - input_len - len(newline_tokens) + 1
|
||||
|
||||
target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
|
||||
|
||||
input_ids[i, :input_len] = input_tokens
|
||||
# add newline between prompt and response
|
||||
newline_plus_inputs = input_len + len(newline_tokens)
|
||||
input_ids[i, input_len: newline_plus_inputs] = newline_tokens
|
||||
# add target tokens, remove bos
|
||||
input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
|
||||
|
||||
labels = input_ids[i].clone()
|
||||
labels[: newline_plus_inputs] = -100
|
||||
labels[labels == tokenizer.pad_token_id] = -100
|
||||
# to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
|
||||
|
||||
attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int()
|
||||
|
||||
out["labels"].append(labels)
|
||||
out["attention_mask"].append(attention_mask)
|
||||
|
||||
out["input_ids"] = input_ids
|
||||
|
||||
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def load_data(config, tokenizer):
|
||||
dataset_path = config["dataset_path"]
|
||||
|
||||
if os.path.exists(dataset_path):
|
||||
# check if path is a directory
|
||||
if os.path.isdir(dataset_path):
|
||||
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
|
||||
else:
|
||||
files = [dataset_path]
|
||||
|
||||
dataset = load_dataset("json", data_files=files, split="train")
|
||||
|
||||
else:
|
||||
dataset = load_dataset(dataset_path)
|
||||
|
||||
|
||||
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
||||
|
||||
train_dataset, val_dataset = dataset["train"], dataset["test"]
|
||||
|
||||
if config["streaming"] is False:
|
||||
kwargs = {"num_proc": config["num_proc"]}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
# tokenize inputs and return labels and attention mask
|
||||
train_dataset = train_dataset.map(
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||
batched=True,
|
||||
remove_columns=["source", "prompt"],
|
||||
**kwargs
|
||||
)
|
||||
val_dataset = val_dataset.map(
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||
batched=True,
|
||||
remove_columns=["source", "prompt"],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
val_dataset = val_dataset.with_format("torch")
|
||||
|
||||
# create dataloader with default data collator since we already have labels
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
batch_size=config["batch_size"],
|
||||
)
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
batch_size=config["batch_size"],
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader
|
10
read.py
Normal file
10
read.py
Normal file
@ -0,0 +1,10 @@
|
||||
import yaml
|
||||
|
||||
|
||||
def read_config(path):
|
||||
# read yaml and return contents
|
||||
with open(path, 'r') as file:
|
||||
try:
|
||||
return yaml.safe_load(file)
|
||||
except yaml.YAMLError as exc:
|
||||
print(exc)
|
187
train.py
Normal file
187
train.py
Normal file
@ -0,0 +1,187 @@
|
||||
import os
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from argparse import ArgumentParser
|
||||
from read import read_config
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
|
||||
from peft import get_peft_model, LoraConfig, TaskType
|
||||
from data import load_data
|
||||
from torchmetrics import MeanMetric
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def format_metrics(metrics, split, prefix=""):
|
||||
log = f"[{split}]" + prefix
|
||||
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
|
||||
|
||||
return log
|
||||
|
||||
|
||||
def evaluate(config, model, val_dataloader):
|
||||
model.eval()
|
||||
val_loss = MeanMetric().to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for i, batch in enumerate(
|
||||
tqdm(val_dataloader),
|
||||
):
|
||||
if i == config["eval_steps"]:
|
||||
break
|
||||
|
||||
loss = model(**batch).loss
|
||||
|
||||
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
|
||||
|
||||
val_loss.update(loss_values["loss"])
|
||||
|
||||
return val_loss
|
||||
|
||||
|
||||
def train(accelerator, config):
|
||||
set_seed(config['seed'])
|
||||
|
||||
accelerator.print(config)
|
||||
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'])
|
||||
# llama has no pad token, set it to eos
|
||||
if tokenizer.pad_token is None:
|
||||
# these tokens are already in the vocab, just not mapped correctly
|
||||
tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"})
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
with accelerator.main_process_first():
|
||||
train_dataloader, val_dataloader = load_data(config, tokenizer)
|
||||
|
||||
|
||||
checkpoint = config["gradient_checkpointing"]
|
||||
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
||||
use_cache=False if checkpoint else True,
|
||||
trust_remote_code=True)
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
if config["lora"]:
|
||||
peft_config = LoraConfig(
|
||||
# should R be configurable?
|
||||
task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
optimizer_cls = (
|
||||
torch.optim.AdamW
|
||||
if accelerator.state.deepspeed_plugin is None
|
||||
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
else DummyOptim
|
||||
)
|
||||
|
||||
# karpathy doesn't decay embeddding, maybe we should exclude
|
||||
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
|
||||
optimizer = optimizer_cls(model.parameters(), lr=config["lr"])
|
||||
|
||||
# scheduler defined in Deepspeed config
|
||||
scheduler = DummyScheduler(
|
||||
optimizer, warmup_num_steps=config["warmup_steps"],
|
||||
)
|
||||
|
||||
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, val_dataloader, scheduler
|
||||
)
|
||||
|
||||
# setup for saving training states in case preemption
|
||||
accelerator.register_for_checkpointing(scheduler)
|
||||
|
||||
if config["checkpoint"]:
|
||||
accelerator.load_state(config["checkpoint"])
|
||||
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
|
||||
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
|
||||
training_difference = os.path.splitext(path)[0]
|
||||
resume_step = int(training_difference.replace("step_", ""))
|
||||
accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
accelerator.print(f"Resuming from step {resume_step}")
|
||||
|
||||
train_loss = MeanMetric().to(model.device)
|
||||
|
||||
for step, batch in enumerate(tqdm(train_dataloader)):
|
||||
model.train()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
|
||||
# log LR in case something weird happens
|
||||
if step % (config["eval_every"] // 10) == 0:
|
||||
if config["wandb"]:
|
||||
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step)
|
||||
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
|
||||
train_loss.update(loss_values["loss"])
|
||||
|
||||
if step > 0 and step % config["save_every"] == 0:
|
||||
accelerator.save_state(f"{config['output_dir']}/step_{step}")
|
||||
|
||||
if step > 0 and step % config["eval_every"] == 0:
|
||||
val_loss = evaluate(config, model, val_dataloader)
|
||||
|
||||
log_train = {
|
||||
"train_loss": train_loss.compute()
|
||||
}
|
||||
log_val = {
|
||||
"val_loss": val_loss.compute()
|
||||
}
|
||||
|
||||
if config["wandb"]:
|
||||
accelerator.log({**log_train, **log_val}, step=step)
|
||||
|
||||
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
|
||||
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
|
||||
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
|
||||
|
||||
train_loss.reset()
|
||||
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(
|
||||
f"{config['output_dir']}/final",
|
||||
is_main_process=accelerator.is_main_process,
|
||||
save_function=accelerator.save,
|
||||
state_dict=accelerator.get_state_dict(model),
|
||||
)
|
||||
|
||||
unwrapped_model.push_to_hub(config["save_name"], private=True)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# parse arguments by reading in a config
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="config.yaml")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = read_config(args.config)
|
||||
|
||||
if config["wandb"]:
|
||||
accelerator = Accelerator(log_with="wandb")
|
||||
accelerator.init_trackers(
|
||||
project_name=config["wandb_project_name"],
|
||||
config=config,
|
||||
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
|
||||
)
|
||||
else:
|
||||
accelerator = Accelerator()
|
||||
|
||||
train(accelerator, config=config)
|
Loading…
Reference in New Issue
Block a user