mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Tokenization improvements
This commit is contained in:
parent
cd08eb0753
commit
ad8ac545a5
@ -202,8 +202,8 @@ class ExllamaModel:
|
|||||||
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
|
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
|
||||||
|
|
||||||
def decode(self, ids, **kwargs):
|
def decode(self, ids, **kwargs):
|
||||||
if isinstance(ids, int):
|
if isinstance(ids, list):
|
||||||
ids = torch.tensor([[ids]])
|
ids = torch.tensor([ids])
|
||||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
ids = ids.view(1, -1)
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
|
@ -107,8 +107,8 @@ class Exllamav2Model:
|
|||||||
return self.tokenizer.encode(string, add_bos=True)
|
return self.tokenizer.encode(string, add_bos=True)
|
||||||
|
|
||||||
def decode(self, ids, **kwargs):
|
def decode(self, ids, **kwargs):
|
||||||
if isinstance(ids, int):
|
if isinstance(ids, list):
|
||||||
ids = torch.tensor([[ids]])
|
ids = torch.tensor([ids])
|
||||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
ids = ids.view(1, -1)
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
|
@ -98,8 +98,8 @@ class LlamaCppModel:
|
|||||||
|
|
||||||
return self.model.tokenize(string)
|
return self.model.tokenize(string)
|
||||||
|
|
||||||
def decode(self, tokens):
|
def decode(self, ids):
|
||||||
return self.model.detokenize(tokens)
|
return self.model.detokenize(ids).decode('utf-8')
|
||||||
|
|
||||||
def get_logits(self, tokens):
|
def get_logits(self, tokens):
|
||||||
self.model.eval(tokens)
|
self.model.eval(tokens)
|
||||||
|
@ -46,17 +46,14 @@ def get_next_logits(prompt, state, use_samplers, previous):
|
|||||||
scores = output['logits'][-1][-1]
|
scores = output['logits'][-1][-1]
|
||||||
|
|
||||||
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
||||||
topk_values, topk_indices = torch.topk(probs, k=25, largest=True, sorted=True)
|
topk_values, topk_indices = torch.topk(probs, k=50, largest=True, sorted=True)
|
||||||
topk_values = [f"{float(i):.5f}" for i in topk_values]
|
topk_values = [f"{float(i):.5f}" for i in topk_values]
|
||||||
if is_non_hf_exllamav1 or is_non_hf_llamacpp:
|
if is_non_hf_exllamav1 or is_non_hf_llamacpp:
|
||||||
topk_indices = [i.expand((1, 1)) for i in topk_indices]
|
topk_indices = [i.expand((1, 1)) for i in topk_indices]
|
||||||
|
|
||||||
tokens = [shared.tokenizer.decode(i) for i in topk_indices]
|
tokens = [shared.tokenizer.decode(i) for i in topk_indices]
|
||||||
if is_non_hf_llamacpp:
|
|
||||||
tokens = [i.decode('utf-8') for i in tokens] # llamacpp returns bytes, not str
|
|
||||||
|
|
||||||
output = ''
|
output = ''
|
||||||
for row in list(zip(topk_values, tokens)):
|
for row in list(zip(topk_values, tokens)):
|
||||||
output += f"{row[0]} - {repr(row[1])[1:-1]}\n"
|
output += f"{row[0]} - {repr(row[1])}\n"
|
||||||
|
|
||||||
return output, previous
|
return output, previous
|
||||||
|
@ -39,8 +39,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||||||
if generate_func is None:
|
if generate_func is None:
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
logger.error("No model is loaded! Select one in the Model tab.")
|
logger.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield ''
|
raise ValueError('No model is loaded! Select one in the Model tab.')
|
||||||
return
|
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']:
|
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']:
|
||||||
generate_func = generate_reply_custom
|
generate_func = generate_reply_custom
|
||||||
@ -106,6 +105,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||||||
|
|
||||||
|
|
||||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||||
|
if shared.tokenizer is None:
|
||||||
|
logger.error('No tokenizer is loaded')
|
||||||
|
raise ValueError('No tokenizer is loaded')
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']:
|
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt))
|
input_ids = shared.tokenizer.encode(str(prompt))
|
||||||
if shared.model.__class__.__name__ not in ['Exllamav2Model']:
|
if shared.model.__class__.__name__ not in ['Exllamav2Model']:
|
||||||
@ -133,6 +136,10 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||||||
|
|
||||||
|
|
||||||
def decode(output_ids, skip_special_tokens=True):
|
def decode(output_ids, skip_special_tokens=True):
|
||||||
|
if shared.tokenizer is None:
|
||||||
|
logger.error('No tokenizer is loaded')
|
||||||
|
raise ValueError('No tokenizer is loaded')
|
||||||
|
|
||||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||||
|
|
||||||
|
|
||||||
@ -146,11 +153,11 @@ def get_encoded_length(prompt):
|
|||||||
|
|
||||||
def get_token_ids(prompt):
|
def get_token_ids(prompt):
|
||||||
tokens = encode(prompt)[0]
|
tokens = encode(prompt)[0]
|
||||||
decoded_tokens = [shared.tokenizer.decode(i) for i in tokens]
|
decoded_tokens = [shared.tokenizer.decode([i]) for i in tokens]
|
||||||
|
|
||||||
output = ''
|
output = ''
|
||||||
for row in list(zip(tokens, decoded_tokens)):
|
for row in list(zip(tokens, decoded_tokens)):
|
||||||
output += f"{str(int(row[0])).ljust(5)} - {repr(row[1])[1:-1]}\n"
|
output += f"{str(int(row[0])).ljust(5)} - {repr(row[1])}\n"
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user