From 0c224cf4f4d9c85ecce7aaf00af0e880c46fb7ac Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 13 Mar 2023 10:32:28 -0300
Subject: [PATCH] Fix GALACTICA (#285)

---
 modules/text_generation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index f5d2b8d0..d64481b2 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -123,7 +123,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     original_input_ids = input_ids
     output = input_ids[0]
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
-    eos_token_ids = [shared.tokenizer.eos_token_id]
+    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
         eos_token_ids.append(int(encode(eos_token)[0][-1]))
     stopping_criteria_list = transformers.StoppingCriteriaList()
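
For context, a minimal standalone sketch of why the guard is needed, assuming a tokenizer like GALACTICA's whose eos_token_id is None (the tokenizer class and helper function below are hypothetical, not part of the patch):

# Sketch (assumption: GALACTICA's tokenizer reports eos_token_id as None).
# Without the guard, eos_token_ids would contain None, which breaks downstream
# stopping checks that compare generated token ids against this list.

def build_eos_token_ids(tokenizer, extra_eos_id=None):
    # Only seed the list when the tokenizer actually defines an EOS token.
    eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
    if extra_eos_id is not None:
        eos_token_ids.append(int(extra_eos_id))
    return eos_token_ids

class FakeGalacticaTokenizer:
    eos_token_id = None  # GALACTICA-style tokenizer with no EOS token defined

print(build_eos_token_ids(FakeGalacticaTokenizer()))     # [] rather than [None]
print(build_eos_token_ids(FakeGalacticaTokenizer(), 2))  # [2]

With the one-line change, models whose tokenizer defines no EOS token simply start from an empty stop-token list instead of a list containing None.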