From 32a2bbee4ae9e9bcf26c6b10d0386168a42d9f14 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 2 Aug 2023 11:01:29 -0700
Subject: [PATCH] Implement auto_max_new_tokens for ExLlama

---
 modules/exllama.py | 6 +++++-
 modules/loaders.py | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/modules/exllama.py b/modules/exllama.py
index ecfb10a4..00b37b9c 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -94,11 +94,15 @@ class ExllamaModel:
         # Tokenizing the input
         ids = self.generator.tokenizer.encode(prompt)
         ids = ids[:, -get_max_prompt_length(state):]
+        if state['auto_max_new_tokens']:
+            max_new_tokens = state['truncation_length'] - ids.shape[-1]
+        else:
+            max_new_tokens = state['max_new_tokens']
 
         self.generator.gen_begin_reuse(ids)
         initial_len = self.generator.sequence[0].shape[0]
         has_leading_space = False
-        for i in range(state['max_new_tokens']):
+        for i in range(max_new_tokens):
             token = self.generator.gen_single_token()
             if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                 has_leading_space = True
diff --git a/modules/loaders.py b/modules/loaders.py
index 838ecc86..68b48204 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -151,6 +151,7 @@ loaders_samplers = {
         'repetition_penalty_range',
         'seed',
         'ban_eos_token',
+        'auto_max_new_tokens',
     },
     'AutoGPTQ': {
         'temperature',
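
Note (not part of the committed change): a minimal standalone sketch of the
auto_max_new_tokens logic added above, assuming a `state` dict with the same
keys as in the patch; `effective_max_new_tokens` and the example prompt
length are hypothetical names/values for illustration only:

    def effective_max_new_tokens(state, prompt_len):
        # With auto_max_new_tokens enabled, generate until the context
        # window is full: truncation_length minus the tokenized prompt
        # length. Otherwise, use the user-configured max_new_tokens.
        if state['auto_max_new_tokens']:
            return state['truncation_length'] - prompt_len
        return state['max_new_tokens']

    state = {
        'auto_max_new_tokens': True,
        'truncation_length': 2048,
        'max_new_tokens': 200,
    }
    print(effective_max_new_tokens(state, prompt_len=1500))  # 548

Because the prompt is already clipped to get_max_prompt_length(state) before
this computation, the difference is non-negative, so the generation loop
bound stays valid in both branches.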