mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
34 lines
869 B
Python
34 lines
869 B
Python
|
from torch_grammar import GrammarSampler
|
||
|
from transformers.generation.logits_process import LogitsProcessor
|
||
|
|
||
|
from modules import shared
|
||
|
|
||
|
# Module-level cache shared by every GrammarLogitsProcessor instance:
#   sampler        -- GrammarSampler built from `grammar_string` (None = no grammar)
#   grammar        -- logits processor most recently created from `sampler`
#   grammar_string -- raw grammar text that `sampler` was built from (cache key)
sampler = grammar = None
grammar_string = ''
|
||
|
|
||
|
|
||
|
class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that constrains generation with a torch_grammar grammar.

    The ``GrammarSampler`` built from the grammar text is cached in the
    module-level globals ``sampler`` / ``grammar_string``. Constructing a new
    processor with the same text reuses it instead of rebuilding it. Each
    construction creates a fresh processor via ``sampler.logits_processor()``.

    Args:
        string: grammar source text. An empty or whitespace-only string
            disables grammar constraints (scores pass through unchanged).
    """

    def __init__(self, string):
        global sampler, grammar, grammar_string

        if string != grammar_string:
            grammar_text = string.strip()
            # Build the replacement sampler *before* updating the cache key.
            # If the old order were kept (key first, then GrammarSampler), a
            # grammar that fails to parse would be recorded as "loaded", and
            # the next call with the same text would silently reuse the
            # previous grammar's sampler.
            if grammar_text:
                # Strip surrounding whitespace and terminate with a single
                # newline before handing the text to the parser (presumably
                # the torch_grammar parser expects a trailing newline).
                new_sampler = GrammarSampler(grammar_text + '\n', 'root', shared.tokenizer)
            else:
                new_sampler = None
            sampler = new_sampler
            grammar_string = string

        # NOTE(review): the cache is keyed only on the grammar text. If
        # shared.tokenizer can change between calls (e.g. a different model is
        # loaded), the cached sampler may still be bound to the old
        # tokenizer -- confirm whether the key should include the tokenizer.
        grammar = sampler.logits_processor() if sampler is not None else None

        # Keep a per-instance reference so __call__ uses the processor created
        # for *this* instance, even if another instance is constructed later
        # and rebinds the module-level `grammar`.
        self.grammar = grammar

    def __call__(self, input_ids, scores):
        """Return ``scores`` after applying the grammar constraint, if any.

        Args:
            input_ids: tokens generated so far; forwarded unchanged to the
                grammar processor.
            scores: next-token scores; returned unmodified when no grammar
                is active.
        """
        if self.grammar is not None:
            scores = self.grammar(input_ids, scores)
        return scores
|