diff --git a/clean.py b/clean.py
new file mode 100644
index 00000000..9cf8bf57
--- /dev/null
+++ b/clean.py
@@ -0,0 +1,71 @@
+import glob
+import os
+import json
+import pandas as pd
+
+
+prompt_generation_dir = "prompts-reponses"
+for file in glob.glob(os.path.join(prompt_generation_dir, "*.jsonl")):
+    data = []
+    print(file)
+    with open(file) as f:
+        for line in f:
+            try:
+                contents = json.loads(line)
+                data.append(contents)
+            except json.JSONDecodeError:
+                continue
+
+    processed = []
+
+    for item in data:
+        if 'prompt' not in item or 'response' not in item:
+            continue
+        if 'source' not in item:
+            item['source'] = 'unspecified'
+        if 'model_settings' in item:
+            item.pop('model_settings', None)
+
+        for key in list(item.keys()):
+            if key not in ['source', 'prompt', 'response']:
+                #print(item[key])
+                item.pop(key, None)
+
+        if isinstance(item['prompt'], dict):
+            if "value" in item["prompt"]:
+                item["prompt"] = item["prompt"]["value"]
+            elif "description" in item["prompt"]:
+                item["prompt"] = item["prompt"]["description"]
+            else:
+                continue
+
+        elif not isinstance(item['prompt'], str):
+            continue
+
+        if isinstance(item['response'], dict):
+            if "value" in item["response"]:
+                item["response"] = item["response"]["value"]
+            elif "description" in item["response"]:
+                item["response"] = item["response"]["description"]
+            else:
+                continue
+        elif not isinstance(item['response'], str):
+            continue
+
+        if item:
+            processed.append(item)
+
+    df = pd.DataFrame(processed)
+    prev_len = len(df)
+
+    # drop empty or null prompts/responses
+    df = df.dropna(subset=['prompt', 'response'])
+    df = df[df['prompt'] != '']
+    df = df[df['response'] != '']
+    curr_len = len(df)
+
+    print(f"Removed {prev_len - curr_len} rows")
+
+    clean_name = file.split(".jsonl")[0] + "_clean.jsonl"
+    print(f"writing to {clean_name}")
+    df.to_json(clean_name, orient="records", lines=True)
\ No newline at end of file
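
For reference, a minimal sketch of the normalization clean.py applies to a single record; the sample record and its field values are hypothetical, not taken from any real prompts-reponses file:

# Illustration of the per-record normalization above; the record is made up.
import json

raw = '{"source": "example_source", "prompt": {"value": "What is 2+2?"}, "response": "4", "model_settings": {"temp": 0.7}}'
item = json.loads(raw)

item.pop("model_settings", None)
for key in list(item):
    if key not in ("source", "prompt", "response"):
        item.pop(key)
if isinstance(item["prompt"], dict):
    item["prompt"] = item["prompt"].get("value") or item["prompt"].get("description")

print(item)  # {'source': 'example_source', 'prompt': 'What is 2+2?', 'response': '4'}
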
diff --git a/configs/deepspeed/ds_config.json b/configs/deepspeed/ds_config.json
new file mode 100644
index 00000000..6c04bd4d
--- /dev/null
+++ b/configs/deepspeed/ds_config.json
@@ -0,0 +1,48 @@
+{
+    "train_batch_size": "auto",
+    "gradient_accumulation_steps": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "fp16": {
+        "enabled": "auto",
+        "min_loss_scale": 1,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "initial_scale_power": 32
+    },
+    "bf16": {
+        "enabled": "auto"
+    },
+    "gradient_clipping": 1,
+    "zero_optimization": {
+        "stage": 2,
+        "offload_param": {
+            "device": "none"
+        },
+        "offload_optimizer": {
+            "device": "none"
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": [
+                0.9,
+                0.999
+            ],
+            "eps": 1e-08
+        }
+    },
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "warmup_type": "linear"
+        }
+    }
+}
\ No newline at end of file
diff --git a/configs/train/finetune.yaml b/configs/train/finetune.yaml
new file mode 100644
index 00000000..9724bdbd
--- /dev/null
+++ b/configs/train/finetune.yaml
@@ -0,0 +1,28 @@
+# model/tokenizer
+model_name: "zpn/llama-7b"
+tokenizer_name: "zpn/llama-7b"
+gradient_checkpointing: true
+
+# dataset
+streaming: false
+num_proc: 64
+dataset_path: "data.jsonl"
+max_length: 512
+batch_size: 32
+
+# train dynamics
+lr: 5.0e-5
+eval_every: 2000
+eval_steps: 100
+save_every: 2000
+output_dir: "ckpts/llama-7b"
+checkpoint: null
+lora: false
+warmup_steps: 100
+
+# logging
+wandb: false
+wandb_entity: zanussbaum
+wandb_project: llama
+seed: 42
+
diff --git a/configs/train/finetune_lora.yaml b/configs/train/finetune_lora.yaml
new file mode 100644
index 00000000..d5fdf92d
--- /dev/null
+++ b/configs/train/finetune_lora.yaml
@@ -0,0 +1,29 @@
+# model/tokenizer
+model_name: "zpn/llama-7b"
+tokenizer_name: "zpn/llama-7b"
+gradient_checkpointing: false
+save_name: "zpn/vicuna-lora"
+
+# dataset
+streaming: false
+num_proc: 64
+dataset_path: "data"
+max_length: 512
+batch_size: 8
+
+# train dynamics
+lr: 5.0e-5
+eval_every: 2000
+eval_steps: 100
+save_every: 2000
+output_dir: "ckpts/llama-7b"
+checkpoint: null
+lora: true
+warmup_steps: 100
+
+# logging
+wandb: false
+wandb_entity: zanussbaum
+wandb_project: llama
+seed: 42
+
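
A quick pre-launch sanity check that both configs load and carry the expected hyperparameters; a minimal sketch, assuming the repo layout introduced in this diff:

# Load the training YAML and DeepSpeed JSON and print the key settings.
import json
import yaml

with open("configs/train/finetune.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/deepspeed/ds_config.json") as f:
    ds_cfg = json.load(f)

print("lr:", train_cfg["lr"], "batch_size:", train_cfg["batch_size"])
print("zero stage:", ds_cfg["zero_optimization"]["stage"])
# The "auto" fields in ds_config.json (batch size, lr, warmup) are expected to be
# filled in at runtime by the Accelerate/DeepSpeed integration, not by hand.
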
diff --git a/data.py b/data.py
new file mode 100644
index 00000000..6db7a3c9
--- /dev/null
+++ b/data.py
@@ -0,0 +1,108 @@
+import glob
+import torch
+from datasets import load_dataset
+import os
+from torch.utils.data import DataLoader
+from transformers import DefaultDataCollator
+
+
+def tokenize_inputs(config, tokenizer, examples):
+    max_length = config["max_length"]
+    input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
+    # ignore bos
+    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
+
+    out = {"labels": [], "attention_mask": []}
+    for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
+        # HACK to get 512 to work for now
+        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
+        input_len = len(input_tokens)
+
+        # plus one since we remove bos from response
+        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
+
+        target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
+
+        input_ids[i, :input_len] = input_tokens
+        # add newline between prompt and response
+        newline_plus_inputs = input_len + len(newline_tokens)
+        input_ids[i, input_len: newline_plus_inputs] = newline_tokens
+        # add target tokens, remove bos
+        input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
+
+        labels = input_ids[i].clone()
+        labels[: newline_plus_inputs] = -100
+        labels[labels == tokenizer.pad_token_id] = -100
+        # to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
+
+        attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int()
+
+        out["labels"].append(labels)
+        out["attention_mask"].append(attention_mask)
+
+    out["input_ids"] = input_ids
+
+    out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
+
+    return out
+
+
+def load_data(config, tokenizer):
+    dataset_path = config["dataset_path"]
+
+    if os.path.exists(dataset_path):
+        # check if path is a directory
+        if os.path.isdir(dataset_path):
+            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
+        else:
+            files = [dataset_path]
+
+        dataset = load_dataset("json", data_files=files, split="train")
+
+    else:
+        dataset = load_dataset(dataset_path)
+
+    dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
+
+    train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+    if config["streaming"] is False:
+        kwargs = {"num_proc": config["num_proc"]}
+    else:
+        kwargs = {}
+
+    # tokenize inputs and return labels and attention mask
+    train_dataset = train_dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        batched=True,
+        remove_columns=["source", "prompt"],
+        **kwargs
+    )
+    val_dataset = val_dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        batched=True,
+        remove_columns=["source", "prompt"],
+        **kwargs
+    )
+
+    train_dataset = train_dataset.with_format("torch")
+    val_dataset = val_dataset.with_format("torch")
+
+    # create dataloader with default data collator since we already have labels
+    train_dataloader = DataLoader(
+        train_dataset,
+        collate_fn=DefaultDataCollator(),
+        batch_size=config["batch_size"],
+    )
+
+    val_dataloader = DataLoader(
+        val_dataset,
+        collate_fn=DefaultDataCollator(),
+        batch_size=config["batch_size"],
+    )
+
+    return train_dataloader, val_dataloader
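
The in-line comment in tokenize_inputs suggests how to verify the label masking; a minimal sketch of that check, assuming `config` and `tokenizer` already exist (for example, as set up in train.py) and using a made-up record:

# Check that the non-masked label positions decode back to the response,
# following the debugging hint in tokenize_inputs above.
from data import tokenize_inputs

example = {"prompt": ["What color is the sky?"], "response": ["Blue."]}
out = tokenize_inputs(config, tokenizer, example)

labels = out["labels"][0].clone()
labels[labels == -100] = tokenizer.pad_token_id
decoded = tokenizer.decode(labels, skip_special_tokens=True).strip()
assert decoded == example["response"][0], (decoded, example["response"][0])
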
diff --git a/read.py b/read.py
new file mode 100644
index 00000000..bc6a69f3
--- /dev/null
+++ b/read.py
@@ -0,0 +1,10 @@
+import yaml
+
+
+def read_config(path):
+    # read yaml and return contents
+    with open(path, 'r') as file:
+        try:
+            return yaml.safe_load(file)
+        except yaml.YAMLError as exc:
+            print(exc)
\ No newline at end of file
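
Note that read_config prints YAML errors and then falls through to an implicit `return None`, so a caller that wants to fail fast can guard for that; a small sketch, not part of the diff:

# Defensive usage of read_config; the path assumes the configs added above.
from read import read_config

config = read_config("configs/train/finetune.yaml")
if config is None:
    raise SystemExit("config did not parse; see the YAML error printed above")
print(config["model_name"])
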
config["wandb_entity"]}}, + ) + else: + accelerator = Accelerator() + + train(accelerator, config=config)