From 6c4f449b7a0dc5ba8c4b0fa61dd1f9448b65bf2e Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Wed, 12 Jul 2023 15:18:24 -0400
Subject: [PATCH] fix: update train scripts and configs for other models (#1164)

* feat: falcon config
* feat: mpt config
* chore: gitignore
* refactor: step calculation
* fix: attention mask + shuffle on epoch end
* fix: return tensors
* fix: wait for everyone
* chore: config
* chore: ds config
* fix: remove ccols
* fix: logging and saving
* chore: add einops
---
 .gitignore                                    |  3 ++
 .../configs/deepspeed/ds_config_mpt.json      | 49 ++++++++++++++++++
 .../configs/deepspeed/ds_config_pythia.json   | 48 +++++++++++++++++
 .../configs/train/finetune_falcon.yaml        | 34 +++++++++++++
 .../configs/train/finetune_mpt.yaml           | 34 +++++++++++++
 .../configs/train/finetune_openllama.yaml     | 34 +++++++++++++
 gpt4all-training/data.py                      | 19 ++++---
 gpt4all-training/requirements.txt             |  2 +-
 gpt4all-training/train.py                     | 51 +++++++++++--------
 9 files changed, 245 insertions(+), 29 deletions(-)
 create mode 100644 gpt4all-training/configs/deepspeed/ds_config_mpt.json
 create mode 100644 gpt4all-training/configs/deepspeed/ds_config_pythia.json
 create mode 100644 gpt4all-training/configs/train/finetune_falcon.yaml
 create mode 100644 gpt4all-training/configs/train/finetune_mpt.yaml
 create mode 100644 gpt4all-training/configs/train/finetune_openllama.yaml

diff --git a/.gitignore b/.gitignore
index 67cf225f..1e8a5c36 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+*.arrow
+squad_*
+*sbert_embedded*
 *.pkl
 ckpts*
 .deepspeed_env
diff --git a/gpt4all-training/configs/deepspeed/ds_config_mpt.json b/gpt4all-training/configs/deepspeed/ds_config_mpt.json
new file mode 100644
index 00000000..76ed092c
--- /dev/null
+++ b/gpt4all-training/configs/deepspeed/ds_config_mpt.json
@@ -0,0 +1,49 @@
+{
+    "train_batch_size": "auto",
+    "gradient_accumulation_steps": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "fp16": {
+        "enabled": "auto",
+        "min_loss_scale": 1,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "initial_scale_power": 32
+    },
+    "bf16": {
+        "enabled": "auto"
+    },
+    "gradient_clipping": 1.0,
+    "zero_optimization": {
+        "stage": 1,
+        "offload_param": {
+            "device": "none"
+        },
+        "offload_optimizer": {
+            "device": "none"
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": [
+                0.9,
+                0.999
+            ],
+            "eps": 1e-08
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "warmup_type": "linear",
+            "total_num_steps": "auto"
+        }
+    }
+}
\ No newline at end of file
diff --git a/gpt4all-training/configs/deepspeed/ds_config_pythia.json b/gpt4all-training/configs/deepspeed/ds_config_pythia.json
new file mode 100644
index 00000000..6f9b2961
--- /dev/null
+++ b/gpt4all-training/configs/deepspeed/ds_config_pythia.json
@@ -0,0 +1,48 @@
+{
+    "train_batch_size": "auto",
+    "gradient_accumulation_steps": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "fp16": {
+        "enabled": "auto",
+        "min_loss_scale": 1,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "initial_scale_power": 32
+    },
+    "bf16": {
+        "enabled": "auto"
+    },
+    "gradient_clipping": 1.0,
+    "zero_optimization": {
+        "stage": 2,
+        "offload_param": {
+            "device": "none"
+        },
+        "offload_optimizer": {
+            "device": "none"
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
"type": "AdamW", + "params": { + "lr": "auto", + "betas": [ + 0.9, + 0.999 + ], + "eps": 1e-08 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "warmup_type": "linear" + } + } +} \ No newline at end of file diff --git a/gpt4all-training/configs/train/finetune_falcon.yaml b/gpt4all-training/configs/train/finetune_falcon.yaml new file mode 100644 index 00000000..089708bb --- /dev/null +++ b/gpt4all-training/configs/train/finetune_falcon.yaml @@ -0,0 +1,34 @@ +# model/tokenizer +model_name: "tiiuae/falcon-7b" +tokenizer_name: "tiiuae/falcon-7b" +gradient_checkpointing: true +save_name: "nomic-ai/gpt4all-falcon" + +# dataset +streaming: false +num_proc: 64 +dataset_path: "nomic-ai/gpt4all-j-prompt-generations" +revision: "v1.3-groovy" +max_length: 1024 +batch_size: 32 + +# train dynamics +lr: 2.0e-5 +min_lr: 0 +weight_decay: 0.0 +eval_every: 500 +eval_steps: 105 +save_every: 1000 +log_grads_every: 500 +output_dir: "ckpts/falcon" +checkpoint: "/home/paperspace/gpt4all/ckpts/mpt/step_1000" +lora: false +warmup_steps: 500 +num_epochs: 2 + +# logging +wandb: true +wandb_entity: "gpt4all" +wandb_project_name: "gpt4all" +seed: 42 + diff --git a/gpt4all-training/configs/train/finetune_mpt.yaml b/gpt4all-training/configs/train/finetune_mpt.yaml new file mode 100644 index 00000000..4e1f3638 --- /dev/null +++ b/gpt4all-training/configs/train/finetune_mpt.yaml @@ -0,0 +1,34 @@ +# model/tokenizer +model_name: "mosaicml/mpt-7b" +tokenizer_name: "mosaicml/mpt-7b" +gradient_checkpointing: false +save_name: "nomic-ai/mpt-finetuned-round2" + +# dataset +streaming: false +num_proc: 64 +dataset_path: "nomic-ai/gpt4all-j-prompt-generations" +revision: "v1.3-groovy" +max_length: 1024 +batch_size: 8 + +# train dynamics +lr: 2.0e-5 +min_lr: 0 +weight_decay: 0.0 +eval_every: 500 +eval_steps: 105 +save_every: 1000 +log_grads_every: 500 +output_dir: "ckpts/mpt" +checkpoint: null +lora: false +warmup_steps: 500 +num_epochs: 2 + +# logging +wandb: false +wandb_entity: "gpt4all" +wandb_project_name: "gpt4all" +seed: 42 + diff --git a/gpt4all-training/configs/train/finetune_openllama.yaml b/gpt4all-training/configs/train/finetune_openllama.yaml new file mode 100644 index 00000000..6862f611 --- /dev/null +++ b/gpt4all-training/configs/train/finetune_openllama.yaml @@ -0,0 +1,34 @@ +# model/tokenizer +model_name: "openlm-research/open_llama_7b" +tokenizer_name: "openlm-research/open_llama_7b" +gradient_checkpointing: true +save_name: "nomic-ai/gpt4all-openllama" + +# dataset +streaming: false +num_proc: 64 +dataset_path: "nomic-ai/gpt4all-updated" +revision: null +max_length: 1024 +batch_size: 32 + +# train dynamics +lr: 2.0e-5 +min_lr: 0 +weight_decay: 0.0 +eval_every: 500 +log_every: 10 +save_every: 1000 +log_grads_every: 500 +output_dir: "ckpts/falcon" +checkpoint: null +lora: false +warmup_steps: 500 +num_epochs: 3 + +# logging +wandb: true +wandb_entity: "gpt4all" +wandb_project_name: "gpt4all" +seed: 42 + diff --git a/gpt4all-training/data.py b/gpt4all-training/data.py index 8227de00..f10847de 100644 --- a/gpt4all-training/data.py +++ b/gpt4all-training/data.py @@ -12,7 +12,7 @@ def tokenize_inputs(config, tokenizer, examples): # hacky backward compatible different_eos = tokenizer.eos_token != "" - out = {"labels": [], "input_ids": []} + out = {"labels": [], "input_ids": [], "attention_mask": []} for prompt, response in zip(examples["prompt"], examples["response"]): if different_eos: if response.count(" \n") > 0: @@ -49,9 
@@ -49,9 +49,10 @@ def tokenize_inputs(config, tokenizer, examples):
             print(response)
             raise
 
-        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
+        padded = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length, return_tensors="pt")
         out["labels"].append(labels)
-        out["input_ids"].append(input_tokens)
+        out["input_ids"].append(padded["input_ids"])
+        out["attention_mask"].append(padded["attention_mask"])
 
     out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
 
@@ -72,7 +73,7 @@ def load_data(config, tokenizer):
         dataset = load_dataset("json", data_files=files, split="train")
 
     else:
-        dataset = load_dataset(dataset_path, split="train")
+        dataset = load_dataset(dataset_path, split="train", revision=config["revision"] if "revision" in config else None)
 
     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
 
@@ -83,19 +84,23 @@ def load_data(config, tokenizer):
     else:
         kwargs = {}
 
+    cols_to_keep = ["input_ids", "labels", "attention_mask"]
     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
         lambda ele: tokenize_inputs(config, tokenizer, ele),
         batched=True,
-        remove_columns=["source", "prompt"],
         **kwargs
     )
+    remove_cols = [col for col in train_dataset.column_names if col not in cols_to_keep]
+    train_dataset = train_dataset.remove_columns(remove_cols)
+
     val_dataset = val_dataset.map(
         lambda ele: tokenize_inputs(config, tokenizer, ele),
         batched=True,
-        remove_columns=["source", "prompt"],
         **kwargs
     )
+    remove_cols = [col for col in val_dataset.column_names if col not in cols_to_keep]
+    val_dataset = val_dataset.remove_columns(remove_cols)
 
     train_dataset = train_dataset.with_format("torch")
     val_dataset = val_dataset.with_format("torch")
@@ -106,12 +111,14 @@ def load_data(config, tokenizer):
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )
 
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )
 
     return train_dataloader, val_dataloader
diff --git a/gpt4all-training/requirements.txt b/gpt4all-training/requirements.txt
index b38ab36c..110977d2 100644
--- a/gpt4all-training/requirements.txt
+++ b/gpt4all-training/requirements.txt
@@ -1,10 +1,10 @@
 accelerate
 datasets
+einops
 torchmetrics
 evaluate
 transformers>=4.28.0
 wandb
-pip
 peft
 nodelist-inflator
 deepspeed
diff --git a/gpt4all-training/train.py b/gpt4all-training/train.py
index 69ebce28..829041f6 100644
--- a/gpt4all-training/train.py
+++ b/gpt4all-training/train.py
@@ -1,5 +1,5 @@
 import os
-from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
 import torch
 from torch.optim import AdamW
 from argparse import ArgumentParser
@@ -42,7 +42,7 @@ def train(accelerator, config):
     accelerator.print(config)
     accelerator.print(f"Using {accelerator.num_processes} GPUs")
 
-    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
+    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'], use_fast=False)
     # if no pad token, set it to eos
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
@@ -53,6 +53,7 @@ def train(accelerator, config):
 
     checkpoint = config["gradient_checkpointing"]
 
+
     model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                  use_cache=False if checkpoint else True,
                                                  trust_remote_code=True)
@@ -86,7 +87,7 @@ def train(accelerator, config):
     # decay to min_lr instead of 0
     lr_ratio = config["min_lr"] / config["lr"]
     accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
-    total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
+    total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * (config["num_epochs"])
     # instead of decaying to zero, decay to ratio of min_lr / lr
     total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
     accelerator.print(f"Total training steps: {total_num_steps}")
@@ -104,7 +105,7 @@ def train(accelerator, config):
         )
     else:
         scheduler = DummyScheduler(
-            optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
+            optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
         )
 
     model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
@@ -117,26 +118,34 @@ def train(accelerator, config):
     if config["checkpoint"]:
         accelerator.load_state(config["checkpoint"])
         accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
-        path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
+        path = os.path.basename(config["checkpoint"])
         training_difference = os.path.splitext(path)[0]
         resume_step = int(training_difference.replace("step_", ""))
-        accelerator.skip_first_batches(train_dataloader, resume_step)
+        train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
         accelerator.print(f"Resuming from step {resume_step}")
+    else:
+        resume_step = 0
 
 
     # log gradients
     if accelerator.is_main_process and config["wandb"]:
         wandb.watch(model, log_freq=config["log_grads_every"], log="all")
 
-    for epoch in range(config["num_epochs"]):
+
+    accelerator.wait_for_everyone()
+
+    for epoch in range(0, config["num_epochs"]):
         train_loss = MeanMetric(nan_strategy="error").to(model.device)
         for step, batch in enumerate(tqdm(train_dataloader)):
+            curr_step = epoch * len(train_dataloader) + step
             model.train()
             outputs = model(**batch)
             loss = outputs.loss
 
             # gather loss before backprop in case of gradient accumulation
             loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
+            if config["wandb"]:
+                accelerator.log({"loss": torch.mean(loss_values["loss"]).item()}, step=curr_step)
             train_loss.update(loss_values["loss"])
 
             loss = loss / gradient_accumulation_steps
@@ -144,9 +153,8 @@ def train(accelerator, config):
 
             # get gradient norm of all params
             # log LR in case something weird happens
-            if step > 0 and step % (config["eval_every"] // 10) == 0:
+            if step > 0 and step % (config["log_lr_every"]) == 0:
                 if config["wandb"]:
-                    curr_step = step + epoch * len(train_dataloader)
                     accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
 
             if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
@@ -156,7 +164,6 @@ def train(accelerator, config):
 
 
             if step > 0 and step % config["save_every"] == 0:
-                curr_step = step + epoch * len(train_dataloader)
                 accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
 
             if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
@@ -170,7 +177,6 @@ def train(accelerator, config):
                 }
 
                 if config["wandb"]:
-                    curr_step = step + epoch * len(train_dataloader)
                     accelerator.log({**log_train, **log_val}, step=curr_step)
 
                 accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
@@ -181,8 +187,14 @@ def train(accelerator, config):
 
         accelerator.print(f"Epoch {epoch} finished")
         accelerator.print(f"Pushing to HF hub")
-        accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
+
+        unwrapped_model.save_pretrained(
+            f"{config['output_dir']}/epoch_{epoch}",
+            is_main_process=accelerator.is_main_process,
+            save_function=accelerator.save,
+            state_dict=accelerator.get_state_dict(model),
+        )
         try:
             if accelerator.is_main_process:
                 unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
@@ -191,21 +203,16 @@ def train(accelerator, config):
             accelerator.print(e)
             accelerator.print(f"Failed to push to hub")
 
+
+    if config["num_epochs"] > 1:
+        accelerator.wait_for_everyone()
+        unwrapped_model = accelerator.unwrap_model(model)
         unwrapped_model.save_pretrained(
-            f"{config['output_dir']}/epoch_{epoch}",
+            f"{config['output_dir']}/final",
             is_main_process=accelerator.is_main_process,
             save_function=accelerator.save,
             state_dict=accelerator.get_state_dict(model),
         )
-
-    accelerator.wait_for_everyone()
-    unwrapped_model = accelerator.unwrap_model(model)
-    unwrapped_model.save_pretrained(
-        f"{config['output_dir']}/final",
-        is_main_process=accelerator.is_main_process,
-        save_function=accelerator.save,
-        state_dict=accelerator.get_state_dict(model),
-    )
 
 
     accelerator.end_training()
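
Note on the data.py hunks above: a minimal, standalone sketch (not part of the patch) of the tokenizer.pad() behaviour the updated tokenize_inputs relies on. With return_tensors="pt", pad() returns both "input_ids" and an "attention_mask", which is why the mask can now be collected per example instead of being rebuilt later. The tokenizer name below is an illustrative assumption, not something this patch uses.

# Sketch: pad one tokenized example and keep the attention mask.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")  # any causal-LM tokenizer works here
tokenizer.pad_token = tokenizer.eos_token  # mirrors the eos-as-pad fallback in train.py

input_tokens = tokenizer("hello world", truncation=True, max_length=16)["input_ids"]
padded = tokenizer.pad(
    {"input_ids": input_tokens},
    padding="max_length",
    max_length=16,
    return_tensors="pt",
)
# padded["input_ids"] is the padded sequence; padded["attention_mask"] is 1 for
# real tokens and 0 for pad positions, ready to be appended and stacked per batch.
print(padded["input_ids"].shape, padded["attention_mask"].shape)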