fix: cutoff prompt correctly

Zach Nussbaum 2023-03-25 16:48:05 +00:00
parent b6e3ba07c4
commit f51c5c8109


@@ -11,7 +11,7 @@ def generate(tokenizer, prompt, model, config):
     outputs = model.generate(input_ids=input_ids, max_new_tokens=config["max_new_tokens"], temperature=config["temperature"])
-    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
     return decoded[len(prompt):]
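
For reference, a minimal sketch of how the surrounding generate() helper might read after this change, assuming a Hugging Face tokenizer/model pair: only the hunk lines above come from the commit; the tokenization step, the imports, and the example call are illustrative assumptions.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    def generate(tokenizer, prompt, model, config):
        # Encode the prompt into input ids (assumed; this step is outside the hunk shown above).
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        # Generate a continuation using the settings from the config dict.
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=config["max_new_tokens"],
            temperature=config["temperature"],
        )

        # Decode the full sequence, strip surrounding whitespace, and return
        # only the text produced after the prompt (the fix in this commit).
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        return decoded[len(prompt):]

    # Example usage (hypothetical model name and generation settings):
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    print(generate(tokenizer, "Hello, my name is", model,
                   {"max_new_tokens": 32, "temperature": 0.7}))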