Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-10-01 01:26:03 -04:00)
Use gen_begin_reuse in exllama

Commit: 766c760cd7
Parent: 239b11c94b
@@ -41,45 +41,49 @@ class ExllamaModel:
         model = ExLlama(config)
         tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
         cache = ExLlamaCache(model)
+        generator = ExLlamaGenerator(model, tokenizer, cache)
 
         result = self()
         result.config = config
         result.model = model
         result.cache = cache
         result.tokenizer = tokenizer
+        self.generator = generator
         return result, result
 
     def generate(self, prompt, state, callback=None):
-        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
-        generator.settings.temperature = state['temperature']
-        generator.settings.top_p = state['top_p']
-        generator.settings.top_k = state['top_k']
-        generator.settings.typical = state['typical_p']
-        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        self.generator.settings.temperature = state['temperature']
+        self.generator.settings.top_p = state['top_p']
+        self.generator.settings.top_k = state['top_k']
+        self.generator.settings.typical = state['typical_p']
+        self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
         if state['ban_eos_token']:
-            generator.disallow_tokens([self.tokenizer.eos_token_id])
+            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
+        else:
+            self.generator.disallow_tokens(None)
 
-        text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
+        text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
         return text
 
     def generate_with_streaming(self, prompt, state, callback=None):
-        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
-        generator.settings.temperature = state['temperature']
-        generator.settings.top_p = state['top_p']
-        generator.settings.top_k = state['top_k']
-        generator.settings.typical = state['typical_p']
-        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        self.generator.settings.temperature = state['temperature']
+        self.generator.settings.top_p = state['top_p']
+        self.generator.settings.top_k = state['top_k']
+        self.generator.settings.typical = state['typical_p']
+        self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
         if state['ban_eos_token']:
-            generator.disallow_tokens([self.tokenizer.eos_token_id])
+            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
+        else:
+            self.generator.disallow_tokens(None)
 
-        generator.end_beam_search()
-        ids = generator.tokenizer.encode(prompt)
-        generator.gen_begin(ids)
-        initial_len = generator.sequence[0].shape[0]
-        for i in range(state['max_new_tokens']):
-            token = generator.gen_single_token()
-            yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
-            if token.item() == generator.tokenizer.eos_token_id or shared.stop_everything:
+        self.generator.end_beam_search()
+        ids = self.generator.tokenizer.encode(prompt)
+        self.generator.gen_begin_reuse(ids)
+        initial_len = self.generator.sequence[0].shape[0]
+        for _ in range(state['max_new_tokens']):
+            token = self.generator.gen_single_token()
+            yield (self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:]))
+            if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                 break
 
     def encode(self, string, **kwargs):
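Why the change helps, in brief: the old code built a fresh ExLlamaGenerator on every request and called gen_begin(), which re-ingests the full prompt each time. The new code keeps one long-lived generator and calls gen_begin_reuse(), which can keep the cached keys/values for the prefix the new prompt shares with the previous sequence, the common case in chat, where each request repeats the conversation so far. The new else branch (disallow_tokens(None)) is needed for the same reason: the reused generator now carries state between requests, so a previously banned EOS token must be explicitly un-banned.

Below is a minimal standalone sketch of the reused-generator pattern the commit adopts. The ExLlama calls are the ones shown in the diff; everything else is an assumption: the imports follow ExLlama's module layout (model.py, tokenizer.py, generator.py importable, as modules/exllama.py arranges via sys.path), the file paths are placeholders, and stream_generate is a hypothetical helper, not code from this repository.

    from model import ExLlama, ExLlamaCache, ExLlamaConfig
    from tokenizer import ExLlamaTokenizer
    from generator import ExLlamaGenerator

    # Placeholder paths; point these at a real GPTQ model directory.
    config = ExLlamaConfig('model/config.json')
    config.model_path = 'model/model.safetensors'

    model = ExLlama(config)
    tokenizer = ExLlamaTokenizer('model/tokenizer.model')
    cache = ExLlamaCache(model)

    # One long-lived generator, mirroring self.generator in the diff above,
    # so its cache persists across calls instead of being rebuilt per request.
    generator = ExLlamaGenerator(model, tokenizer, cache)

    def stream_generate(prompt, max_new_tokens):
        # Hypothetical helper condensing generate_with_streaming() above.
        ids = generator.tokenizer.encode(prompt)
        # gen_begin_reuse() keeps cached keys/values for the prefix shared
        # with the previous call; gen_begin() would recompute everything.
        generator.gen_begin_reuse(ids)
        initial_len = generator.sequence[0].shape[0]
        for _ in range(max_new_tokens):
            token = generator.gen_single_token()
            yield generator.tokenizer.decode(generator.sequence[0][initial_len:])
            if token.item() == generator.tokenizer.eos_token_id:
                break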