Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-10-01 01:26:03 -04:00)

fix: grammar does not support UTF-8 (#5900)

This commit is contained in:
parent 8456d13349
commit 5cb59707f3
@@ -60,7 +60,7 @@ def hex_to_int(c):
         return int(c)
     elif "a" <= c.lower() <= "f":
         return ord(c.lower()) - ord("a") + 10
-    return -1
+    raise RuntimeError("unknown hex char " + c)


 def remove_leading_white_space(src, newline_ok):
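Note: hex_to_int previously signaled an invalid digit by returning -1, which every caller had to check; the patch makes it raise instead, since the new read_hex helper (added in the next hunk) folds digits without per-digit error handling. The patched helper, reproduced as a standalone sketch:

    def hex_to_int(c):
        if c.isdigit():
            return int(c)
        elif "a" <= c.lower() <= "f":
            return ord(c.lower()) - ord("a") + 10
        raise RuntimeError("unknown hex char " + c)

    hex_to_int("B")  # 11
    hex_to_int("G")  # raises RuntimeError instead of silently returning -1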
@@ -100,6 +100,13 @@ def parse_name(src):
     return src[:pos], src[pos:]


+def read_hex(s):
+    val = 0
+    for c in s:
+        val = (val << 4) + hex_to_int(c)
+    return chr(val)
+
+
 def parse_char(src):
     """
     parse the leading char from the input string
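The new read_hex turns a run of hex digits of any length into a single character, which is what lets the escape parser below reach code points above 0xFF. Example calls with assumed inputs:

    read_hex("41")        # 'A'  (U+0041)
    read_hex("00e9")      # 'é'  (U+00E9)
    read_hex("0001F600")  # '😀' (U+1F600), impossible to express as a one-byte \xNN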
@@ -111,13 +118,12 @@ def parse_char(src):
     if src[0] == "\\":
         esc = src[1]
         if esc == "x":
-            first = hex_to_int(src[2])
-            if first > -1:
-                second = hex_to_int(src[3])
-                if second > -1:
-                    return (first << 4) + second, src[4:]
-            raise RuntimeError("expecting \\xNN at " + src)
-        elif esc in ('"', "[", "]"):
+            return read_hex(src[2:4]), src[4:]
+        elif esc == "u":
+            return read_hex(src[2:6]), src[6:]
+        elif esc == "U":
+            return read_hex(src[2:10]), src[10:]
+        elif esc in ('"', "[", "]", "\\", "-"):
             return esc, src[2:]
         elif esc == "r":
             return "\r", src[2:]
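parse_char now dispatches on the escape letter to pick how many hex digits to consume: \xNN takes two, \uNNNN four, \UNNNNNNNN eight, and escaped backslash and hyphen are now accepted as literals. Hypothetical inputs illustrating the slicing, assuming parse_char and read_hex as patched above:

    parse_char("\\x41rest")        # ('A', 'rest')   -- reads src[2:4], remainder src[4:]
    parse_char("\\u00e9rest")      # ('é', 'rest')   -- reads src[2:6], remainder src[6:]
    parse_char("\\U0001F600rest")  # ('😀', 'rest')  -- reads src[2:10], remainder src[10:]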
@@ -454,7 +460,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
     def __init__(self, grammar_str, start_rule_name, tokenizer):
         super().__init__(grammar_str, start_rule_name, tokenizer)

-    def accept_char(self, byte, stacks):
+    def accept_char(self, char, stacks):
+        byte = ord(char)
         new_stacks = []
         for stack in stacks:
             # stack is empty
@@ -471,6 +478,9 @@ class IncrementalGrammarConstraint(GrammarConstraint):
                 if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]:
                     found = True
                     break
+                if self.grammar_encoding[pos + i] >= byte and byte >= self.grammar_encoding[pos + i + 1]:
+                    found = True
+                    break
             if not found:
                 continue

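accept_char now receives a character and compares its code point (ord(char)) against the [start, end] pairs stored in grammar_encoding; the added mirror-image comparison also matches pairs that happen to be encoded in descending order. A hypothetical standalone helper condensing the two tests, with a made-up flat bounds list:

    def in_any_range(code_point, bounds):
        # bounds holds flat (start, end) pairs; accept either orientation
        for i in range(0, len(bounds), 2):
            lo, hi = bounds[i], bounds[i + 1]
            if lo <= code_point <= hi or hi <= code_point <= lo:
                return True
        return False

    in_any_range(ord("é"), [0x61, 0x7A, 0xC0, 0xFF])  # True: U+00E9 falls in 0xC0..0xFF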
@@ -483,9 +493,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
         return new_stacks

     def accept_string(self, string: str, stacks: List[List[int]]):
-        _bytes = bytes(string, "utf-8")
-        for byte in _bytes:
-            stacks = self.accept_char(byte, stacks)
+        for char in string:
+            stacks = self.accept_char(char, stacks)
         return stacks

     def accept_token_id(self, token_id: int, stacks: List[List[int]]):
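This is the heart of the UTF-8 fix: accept_string used to decompose the string into UTF-8 bytes, so a multi-byte character could never match a single grammar range; it now feeds characters through one at a time. The difference in what the matcher sees:

    list(bytes("é", "utf-8"))  # [195, 169] -- two UTF-8 bytes, neither equals the character
    ord("é")                   # 233        -- the single code point the ranges now test against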
@@ -537,16 +546,18 @@ class IncrementalGrammarConstraint(GrammarConstraint):

     # For each sub-rule in the grammar, cache whether each byte is accepted.
     @lru_cache(maxsize=None)
-    def pos_char_acceptance(self, pos):
-        acceptance = [False] * 256
+    def pos_char_acceptance(self, pos, char):
+        byte = ord(char)
         num_chars = self.grammar_encoding[pos]
         pos += 1
         for i in range(0, num_chars, 2):
             start = self.grammar_encoding[pos + i]
             end = self.grammar_encoding[pos + i + 1]
-            for j in range(start, end + 1):
-                acceptance[j] = True
-        return acceptance
+            if byte >= start and byte <= end:
+                return True
+            if byte <= start and byte >= end:
+                return True
+        return False

     # Probably this should be configurable. If the grammar has an exceedingly
     # large number of states, the correct setting is a tradeoff between GPU
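Because code points are unbounded, the fixed 256-entry acceptance table no longer works; pos_char_acceptance becomes a per-(pos, char) predicate, and @lru_cache now memoizes on both arguments. A self-contained toy analogue with a made-up encoding (one sub-rule at pos 0 holding two ranges):

    from functools import lru_cache

    GRAMMAR_ENCODING = [4, 0x61, 0x7A, 0xC0, 0xFF]  # hypothetical: ranges a-z and U+00C0-U+00FF

    @lru_cache(maxsize=None)
    def pos_char_acceptance(pos, char):
        byte = ord(char)
        num_chars = GRAMMAR_ENCODING[pos]
        pos += 1
        for i in range(0, num_chars, 2):
            start = GRAMMAR_ENCODING[pos + i]
            end = GRAMMAR_ENCODING[pos + i + 1]
            if start <= byte <= end or end <= byte <= start:
                return True
        return False

    pos_char_acceptance(0, "é")  # True, and cached for repeated (pos, char) queries

Call sites change accordingly, from indexing a table (pos_char_acceptance(pos)[byte]) to calling the predicate (pos_char_acceptance(pos, byte)), as the next hunk shows.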
@@ -580,7 +591,7 @@ class IncrementalGrammarConstraint(GrammarConstraint):
             pos = stk[-1]
             num_chars = self.grammar_encoding[pos]

-            if not self.pos_char_acceptance(pos)[byte]:
+            if not self.pos_char_acceptance(pos, byte):
                 continue

             pos += num_chars + 1
@@ -657,14 +668,14 @@ class TokenTrie:
                 token = tokenizer.convert_ids_to_tokens(id)
                 token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
                 token = token.replace("▁", " ")
-                return bytes(token, "utf-8")
+                return token

         else:
             print("Warning: unrecognized tokenizer: using default token formatting")

             def fmt_token(id):
                 token = tokenizer.convert_ids_to_tokens(id)
-                return bytes(token, "utf-8")
+                return token

         # note: vocab_size doesn't work here because there are also
         # get_added_vocab() tokens
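With the rest of the matcher operating on characters, TokenTrie keeps tokens as str instead of encoding them to UTF-8 bytes, so the trie is built over code points. A sketch of the normalization on an assumed SentencePiece-style token, with replace_hex presumed to decode <0xNN> byte markers as the module's helper does:

    import re

    def replace_hex(match):
        return chr(int(match.group(1), 16))  # assumed behavior of the module's helper

    token = "▁hi<0x0A>"
    token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
    token = token.replace("▁", " ")
    token  # ' hi\n' -- returned as str now, not bytes(token, "utf-8")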