fix: grammar not support utf-8 (#5900)

This commit is contained in:
A0nameless0man 2024-05-20 07:10:39 +08:00 committed by GitHub
parent 8456d13349
commit 5cb59707f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -60,7 +60,7 @@ def hex_to_int(c):
return int(c)
elif "a" <= c.lower() <= "f":
return ord(c.lower()) - ord("a") + 10
return -1
raise RuntimeError("unknown hex char " + c)
def remove_leading_white_space(src, newline_ok):
@ -100,6 +100,13 @@ def parse_name(src):
return src[:pos], src[pos:]
def read_hex(s):
    """
    decode a run of hex digits into the single character with that code point

    Digits are read most-significant first; each one is converted with
    hex_to_int and shifted into the accumulated value.

    :param s: string of hex digits (callers pass a fixed-width slice of the
        grammar source, so a truncated escape arrives here as a short or
        empty string)
    :return: one-character string, chr() of the accumulated value
    :raises RuntimeError: if s is empty; without this check an escape that is
        cut off at the end of the input would silently decode to the NUL
        character
    :raises ValueError: from chr() when the value is above 0x10FFFF
    """
    if not s:
        raise RuntimeError("expecting hex digits, got empty string")
    val = 0
    for c in s:
        # each hex digit contributes four bits; what happens for a non-hex
        # character is decided by hex_to_int
        val = (val << 4) + hex_to_int(c)
    return chr(val)
def parse_char(src):
"""
parse the leading char from the input string
@ -111,13 +118,12 @@ def parse_char(src):
if src[0] == "\\":
esc = src[1]
if esc == "x":
first = hex_to_int(src[2])
if first > -1:
second = hex_to_int(src[3])
if second > -1:
return (first << 4) + second, src[4:]
raise RuntimeError("expecting \\xNN at " + src)
elif esc in ('"', "[", "]"):
return read_hex(src[2:4]), src[4:]
elif esc == "u":
return read_hex(src[2:6]), src[6:]
elif esc == "U":
return read_hex(src[2:10]), src[10:]
elif esc in ('"', "[", "]", "\\", "-"):
return esc, src[2:]
elif esc == "r":
return "\r", src[2:]
@ -454,7 +460,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
def __init__(self, grammar_str, start_rule_name, tokenizer):
super().__init__(grammar_str, start_rule_name, tokenizer)
def accept_char(self, byte, stacks):
def accept_char(self, char, stacks):
byte = ord(char)
new_stacks = []
for stack in stacks:
# stack is empty
@ -471,6 +478,9 @@ class IncrementalGrammarConstraint(GrammarConstraint):
if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]:
found = True
break
if self.grammar_encoding[pos + i] >= byte and byte >= self.grammar_encoding[pos + i + 1]:
found = True
break
if not found:
continue
@ -483,9 +493,8 @@ class IncrementalGrammarConstraint(GrammarConstraint):
return new_stacks
def accept_string(self, string: str, stacks: List[List[int]]):
_bytes = bytes(string, "utf-8")
for byte in _bytes:
stacks = self.accept_char(byte, stacks)
for char in string:
stacks = self.accept_char(char, stacks)
return stacks
def accept_token_id(self, token_id: int, stacks: List[List[int]]):
@ -537,16 +546,18 @@ class IncrementalGrammarConstraint(GrammarConstraint):
# For each sub-rule in the grammar, cache whether each byte is accepted.
@lru_cache(maxsize=None)
def pos_char_acceptance(self, pos):
acceptance = [False] * 256
def pos_char_acceptance(self, pos, char):
byte = ord(char)
num_chars = self.grammar_encoding[pos]
pos += 1
for i in range(0, num_chars, 2):
start = self.grammar_encoding[pos + i]
end = self.grammar_encoding[pos + i + 1]
for j in range(start, end + 1):
acceptance[j] = True
return acceptance
if byte >= start and byte <= end:
return True
if byte <= start and byte >= end:
return True
return False
# Probably this should be configurable. If the grammar has an exceedingly
# large number of states, the correct setting is a tradeoff between GPU
@ -580,7 +591,7 @@ class IncrementalGrammarConstraint(GrammarConstraint):
pos = stk[-1]
num_chars = self.grammar_encoding[pos]
if not self.pos_char_acceptance(pos)[byte]:
if not self.pos_char_acceptance(pos, byte):
continue
pos += num_chars + 1
@ -657,14 +668,14 @@ class TokenTrie:
token = tokenizer.convert_ids_to_tokens(id)
token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
token = token.replace("▁", " ")
return bytes(token, "utf-8")
return token
else:
print("Warning: unrecognized tokenizer: using default token formatting")
def fmt_token(id):
token = tokenizer.convert_ids_to_tokens(id)
return bytes(token, "utf-8")
return token
# note: vocab_size doesn't work here because there are also
# get_added_vocab() tokens