Make stop_everything work with non-streamed generation (#2848)

This commit is contained in:
快乐的我531 2023-06-24 22:19:16 +08:00 committed by GitHub
parent ec482f3dae
commit e356f69b36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 2 deletions

View File

@ -9,6 +9,14 @@ import transformers
import modules.shared as shared
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
return shared.stop_everything
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func

View File

@ -9,7 +9,8 @@ import torch
import transformers
import modules.shared as shared
from modules.callbacks import Iteratorize, Stream
from modules.callbacks import (Iteratorize, Stream,
_StopEverythingStoppingCriteria)
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.logging_colors import logger
@ -252,10 +253,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
# Find the eos tokens
# Stopping criteria / eos token
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria());
t0 = time.time()
try: