Merge remote-tracking branch 'refs/remotes/origin/dev' into dev

This commit is contained in:
oobabooga 2023-12-08 05:02:25 -08:00
commit 00aedf9209
3 changed files with 33 additions and 3 deletions

View File

@ -165,10 +165,19 @@ class ExllamaModel:
if has_leading_space:
decoded_text = ' ' + decoded_text
yield decoded_text
# Check the partial unicode character
if chr(0xfffd) in decoded_text:
is_last = i == max_new_tokens - 1
is_stopping = token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything
# If we are not at the end of the generation, we skip this token
if not (is_last or is_stopping):
continue
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
break
yield decoded_text
# Case 2: CFG
# Copied from https://github.com/turboderp/exllama/blob/master/example_cfg.py
else:
@ -205,6 +214,14 @@ class ExllamaModel:
if has_leading_space:
decoded_text = ' ' + decoded_text
# Check the partial unicode character
if chr(0xfffd) in decoded_text:
is_last = i == max_new_tokens - 1
is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
# If we are not at the end of the generation, we skip this token
if not (is_last or is_stopping):
continue
yield decoded_text
if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
break

View File

@ -138,11 +138,19 @@ class Exllamav2Model:
if has_leading_space:
decoded_text = ' ' + decoded_text
yield decoded_text
# Check the partial unicode character
if chr(0xfffd) in decoded_text:
is_last = i == max_new_tokens - 1
is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
# If we are not at the end of the generation, we skip this token
if not (is_last or is_stopping):
continue
if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
break
yield decoded_text
def generate(self, prompt, state):
output = ''
for output in self.generate_with_streaming(prompt, state):

View File

@ -362,7 +362,12 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
if output[-1] in eos_token_ids:
break
cumulative_reply += get_reply_from_output_ids(output, state, starting_from=starting_from)
new_content = get_reply_from_output_ids(output, state, starting_from=starting_from)
# check the partial unicode character
if chr(0xfffd) in new_content:
continue
cumulative_reply += new_content
starting_from = len(output)
yield cumulative_reply