From d910d435cdf16e0040ecea12d3fb116ecee140ab Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 14 Feb 2023 12:06:47 -0300
Subject: [PATCH] Consider the softprompt in the maximum prompt length
 calculation

---
 server.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/server.py b/server.py
index c079a2aa..fd2bf7bf 100644
--- a/server.py
+++ b/server.py
@@ -247,8 +247,15 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
+def get_max_prompt_length(tokens):
+    global soft_prompt, soft_prompt_tensor
+    max_length = 2048-tokens
+    if soft_prompt:
+        max_length -= soft_prompt_tensor.shape[1]
+    return max_length
+
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
+    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
     if args.cpu:
         return input_ids
     elif args.deepspeed:
@@ -497,7 +504,8 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
     rows = [f"{context.strip()}\n"]
     i = len(history['internal'])-1
     count = 0
-    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
+    max_length = get_max_prompt_length(tokens)
+    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
         count += 1
         if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
@@ -515,7 +523,7 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
         rows.append(f"{name1}:")
         limit = 2
 
-    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
+    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length:
         rows.pop(1)
         rows.pop(1)
 
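Note: for context, a minimal standalone sketch of the budget arithmetic this
patch introduces. The 2048-token context window matches the constant used in
server.py; the 20-token soft prompt and the embedding width below are made-up
values for illustration, and soft_prompt/soft_prompt_tensor here stand in for
the globals the server maintains.

    import torch

    CONTEXT_LENGTH = 2048  # context window hard-coded in server.py

    # Hypothetical soft prompt: 20 virtual tokens, embedding width 4096.
    soft_prompt = True
    soft_prompt_tensor = torch.zeros(1, 20, 4096)

    def get_max_prompt_length(tokens):
        # Reserve room for the tokens to be generated...
        max_length = CONTEXT_LENGTH - tokens
        # ...and, as of this patch, for the soft prompt's virtual tokens,
        # which are prepended to the input embeddings at generation time.
        if soft_prompt:
            max_length -= soft_prompt_tensor.shape[1]
        return max_length

    print(get_max_prompt_length(200))  # 2048 - 200 - 20 = 1828

Before this patch, the prompt was truncated to 2048-tokens_to_generate, so a
loaded soft prompt could push the combined input past the context window.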