fix: add tokens

This commit is contained in:
Zach Nussbaum 2023-03-27 17:39:20 +00:00
parent eac7734cbf
commit d87af69a93

View File

@ -19,7 +19,10 @@ def generate(tokenizer, prompt, model, config):
def setup_model(config):
model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"})
added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
if added_tokens > 0:
model.resize_token_embeddings(len(tokenizer))
if config["lora"]:
model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16)