Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-10-01 01:26:03 -04:00)
Fix missing bos token for some models (including Llama-3) (#6050)

commit a363cdfca1 (parent 8df68b05e9)
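Some models, including Llama-3, ship tokenizer metadata that prevents the BOS token from being prepended during encoding, so prompts reach the model without the expected start-of-sequence marker. A quick illustrative check, not part of the commit, using the Hugging Face transformers API (the model id is only an example):

# Illustrative check: does the loaded tokenizer actually prepend its BOS token?
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
ids = tok.encode("Hello", add_special_tokens=True)

# With faulty metadata, ids[0] will not be the BOS id even though
# add_special_tokens=True was requested.
print(tok.bos_token_id, ids[:3], ids[0] == tok.bos_token_id)

The diff from commit a363cdfca1 reworks the BOS handling inside encode() accordingly: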
@@ -138,9 +138,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
-        if not add_bos_token:
-            while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
-                input_ids = input_ids[:, 1:]
+
+        if hasattr(shared.tokenizer, 'bos_token_id'):
+            if add_bos_token:
+                if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0:
+                    # Add a missing bos token (it may not have been added due to faulty model metadata)
+                    bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
+                    input_ids = torch.cat((bos_tensor, input_ids), 1)
+
+                # Prevent double bos token due to jinja templates with <s> somewhere
+                while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
+                    input_ids = input_ids[:, 1:]
+            else:
+                # Remove any bos token that may have been added
+                while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
+                    input_ids = input_ids[:, 1:]
 
     # Handling truncation
     if truncation_length is not None:
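Read in isolation, the added block is a small BOS-normalization step: prepend the BOS id when the tokenizer failed to add it, collapse duplicates introduced by a chat template that already contains <s>, and strip it entirely when add_bos_token is disabled. Below is a minimal standalone sketch of that behavior, not the repository's code: normalize_bos is a hypothetical name, and the BOS id of 1 and the example tensors are illustrative only.

import torch


def normalize_bos(input_ids: torch.Tensor, bos_token_id: int, add_bos_token: bool) -> torch.Tensor:
    # Ensure exactly one leading BOS token when add_bos_token is True, none otherwise.
    if add_bos_token:
        # Add a missing bos token (e.g. when model metadata disables it incorrectly)
        if len(input_ids[0]) == 0 or input_ids[0][0] != bos_token_id:
            bos_tensor = torch.tensor([[bos_token_id]])
            input_ids = torch.cat((bos_tensor, input_ids), 1)

        # Collapse duplicated bos tokens (e.g. a jinja template that already emits <s>)
        while len(input_ids[0]) > 1 and input_ids[0][0] == bos_token_id and input_ids[0][1] == bos_token_id:
            input_ids = input_ids[:, 1:]
    else:
        # Strip any bos tokens the tokenizer added
        while len(input_ids[0]) > 0 and input_ids[0][0] == bos_token_id:
            input_ids = input_ids[:, 1:]

    return input_ids


# A prompt encoded without BOS gets exactly one prepended; duplicates collapse to one;
# with add_bos_token=False any leading BOS is removed.
print(normalize_bos(torch.tensor([[5, 6, 7]]), 1, True))        # tensor([[1, 5, 6, 7]])
print(normalize_bos(torch.tensor([[1, 1, 5, 6, 7]]), 1, True))  # tensor([[1, 5, 6, 7]])
print(normalize_bos(torch.tensor([[1, 5, 6, 7]]), 1, False))    # tensor([[5, 6, 7]])

One design point worth noting from the diff: the whole block is guarded by hasattr(shared.tokenizer, 'bos_token_id'), so tokenizers without a BOS concept pass through unchanged.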