diff --git a/configs/train/finetune_gptj.yaml b/configs/train/finetune_gptj.yaml new file mode 100644 index 00000000..31c25050 --- /dev/null +++ b/configs/train/finetune_gptj.yaml @@ -0,0 +1,33 @@ +# model/tokenizer +model_name: "EleutherAI/gpt-j-6B" +tokenizer_name: "EleutherAI/gpt-j-6B" +gradient_checkpointing: true +save_name: "nomic-ai/gpt4all-gptj-multiturn-lr-aggressive" + +# dataset +streaming: false +num_proc: 64 +dataset_path: "data_multiplus" +max_length: 1024 +batch_size: 8 + +# train dynamics +lr: 2.0e-5 +min_lr: 0 +weight_decay: 0.0 +eval_every: 200 +eval_steps: 105 +save_every: 400 +log_grads_every: 200 +output_dir: "ckpts/gpt4all-gptj-full-multiturn-lr-aggreive" +checkpoint: null +lora: false +warmup_steps: 500 +num_epochs: 4 + +# logging +wandb: true +wandb_entity: vicuna +wandb_project_name: vicuna +seed: 42 + diff --git a/data.py b/data.py index 0e356f7d..0de26cfa 100644 --- a/data.py +++ b/data.py @@ -11,42 +11,38 @@ def tokenize_inputs(config, tokenizer, examples): max_length = config["max_length"] input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id) # ignore bos - newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:] + newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0] + if newline_tokens[0] == tokenizer.bos_token_id: + newline_tokens = newline_tokens[1:] - out = {"labels": [], "attention_mask": []} - for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])): - input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze() - input_len = len(input_tokens) + # hacky backward compatible + different_eos = tokenizer.eos_token != "" + out = {"labels": [], "input_ids": []} + for prompt, response in zip(examples["prompt"], examples["response"]): + if different_eos: + if response.count("") > 0: + response = response.replace("", tokenizer.eos_token) - # plus one since we remove bos from response - # but we subtract one since we want to add eos token - remaining_tokens = max_length - input_len - len(newline_tokens) + 1 - # remove bos - target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:] + prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0]) - input_ids[i, :input_len] = input_tokens - # add newline between prompt and response - newline_plus_inputs = input_len + len(newline_tokens) - input_ids[i, input_len: newline_plus_inputs] = newline_tokens + # hack if our prompt is super long + # we need to include some labels + if prompt_len >= max_length - 1: + prompt = prompt[:len(prompt) // 2] - # add target tokens, remove bos - input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens - # add eos token, enforce stopping if we don't truncate - # we don't want long code to stop generating if truncated during training - if newline_plus_inputs + len(target_tokens) < max_length: - input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id + input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token, + truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze() - labels = input_ids[i].clone() - labels[: newline_plus_inputs] = -100 - labels[labels == tokenizer.pad_token_id] = -100 - # to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response - attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int() + labels = input_tokens.clone() + labels[:prompt_len + len(newline_tokens)] = -100 + if len(labels) < max_length: + # pad to max_length with -100 + labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)]) + input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"] out["labels"].append(labels) - out["attention_mask"].append(attention_mask) - - out["input_ids"] = input_ids + out["input_ids"].append(input_tokens) out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}