From d87af69a933ed288b3e265a124ff4a39479d32fb Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 27 Mar 2023 17:39:20 +0000 Subject: [PATCH] fix: add tokens --- generate.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/generate.py b/generate.py index af46fd55..cc3ed0f8 100644 --- a/generate.py +++ b/generate.py @@ -19,7 +19,10 @@ def generate(tokenizer, prompt, model, config): def setup_model(config): model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16) tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"]) - tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"}) + added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"}) + + if added_tokens > 0: + model.resize_token_embeddings(len(tokenizer)) if config["lora"]: model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16)