text-generation-webui/modules/llamacpp_model.py

192 lines
6.3 KiB
Python
Raw Normal View History

import re
from functools import partial
2023-09-17 09:42:32 -04:00
import numpy as np
import torch
2023-03-19 02:42:10 -04:00
from modules import RoPE, llama_cpp_python_hijack, shared
2023-03-31 13:27:01 -04:00
from modules.callbacks import Iteratorize
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
2023-03-31 13:27:01 -04:00
# Every llama-cpp-python build is optional: a backend that cannot be imported
# is recorded as None so that llama_cpp_lib() can fall back to another one.
# `except Exception` (rather than a bare `except`) keeps this best-effort
# behaviour for import/loader failures without swallowing KeyboardInterrupt
# or SystemExit.
try:
    import llama_cpp
except Exception:
    llama_cpp = None

try:
    import llama_cpp_cuda
except Exception:
    llama_cpp_cuda = None

try:
    import llama_cpp_cuda_tensorcores
except Exception:
    llama_cpp_cuda_tensorcores = None
def llama_cpp_lib():
    """Return the llama-cpp-python module selected by the command-line flags.

    Order of preference:
      1. ``llama_cpp`` when ``shared.args.cpu`` is set and that module imported.
      2. ``llama_cpp_cuda_tensorcores`` when ``shared.args.tensorcores`` is set
         and that module imported.
      3. ``llama_cpp_cuda`` when it imported.
      4. ``llama_cpp`` otherwise (which is ``None`` if it failed to import).
    """
    wants_cpu = shared.args.cpu
    if wants_cpu and llama_cpp is not None:
        return llama_cpp

    wants_tensorcores = shared.args.tensorcores
    if wants_tensorcores and llama_cpp_cuda_tensorcores is not None:
        return llama_cpp_cuda_tensorcores

    if llama_cpp_cuda is not None:
        return llama_cpp_cuda

    # Last resort; may be None when no backend could be imported.
    return llama_cpp
2023-03-19 02:42:10 -04:00
def ban_eos_logits_processor(eos_token, input_ids, logits):
    """Logits processor that makes the end-of-sequence token unselectable.

    Sets ``logits[eos_token]`` to negative infinity in place and returns the
    same ``logits`` object. ``input_ids`` is accepted but not used.
    """
    never = float('-inf')
    logits[eos_token] = never
    return logits
2023-09-15 17:27:27 -04:00
def custom_token_ban_logits_processor(token_ids, input_ids, logits):
    """Logits processor that makes every token in ``token_ids`` unselectable.

    Each listed position of ``logits`` is overwritten in place with negative
    infinity; the same ``logits`` object is returned. ``input_ids`` is
    accepted but not used.
    """
    never = -float('inf')
    for banned in token_ids:
        logits[banned] = never

    return logits
2023-03-19 02:42:10 -04:00
class LlamaCppModel:
    """Thin wrapper around a llama-cpp-python ``Llama`` object.

    The wrapper acts as both the model and the tokenizer: ``from_pretrained``
    returns the same instance twice, and ``encode``/``decode`` delegate to the
    wrapped object's ``tokenize``/``detokenize``.
    """

    def __init__(self):
        self.initialized = False
        # Last grammar source passed to load_grammar() and its compiled form.
        self.grammar_string = ''
        self.grammar = None

    def __del__(self):
        # `model` is only assigned inside from_pretrained(); if loading failed
        # before that point there is nothing to release, and an unconditional
        # `del self.model` would raise AttributeError during finalization.
        if hasattr(self, 'model'):
            del self.model

    @staticmethod
    def _parse_cache_capacity(raw):
        """Turn a cache-capacity setting into a byte count (0 when ``raw`` is None)."""
        if raw is None:
            return 0

        # NOTE(review): the 'GiB' and 'MiB' suffixes are scaled by powers of
        # 1000 rather than 1024. Kept unchanged so existing settings produce
        # the same capacity - confirm whether binary units were intended.
        if 'GiB' in raw:
            return int(re.sub('[a-zA-Z]', '', raw)) * 1000 * 1000 * 1000
        if 'MiB' in raw:
            return int(re.sub('[a-zA-Z]', '', raw)) * 1000 * 1000

        return int(raw)

    @classmethod
    def from_pretrained(cls, path):
        """Load the model file at ``path`` using the settings in ``shared.args``.

        Returns:
            A ``(model, tokenizer)`` pair; both elements are the same
            ``LlamaCppModel`` instance.
        """
        Llama = llama_cpp_lib().Llama
        LlamaCache = llama_cpp_lib().LlamaCache

        result = cls()

        cache_capacity = cls._parse_cache_capacity(shared.args.cache_capacity)
        if cache_capacity > 0:
            logger.info("Cache capacity is " + str(cache_capacity) + " bytes")

        # An unset or blank tensor split is passed through as None.
        if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '':
            tensor_split_list = None
        else:
            tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]

        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
            'n_threads': shared.args.threads or None,
            'n_threads_batch': shared.args.threads_batch or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'mul_mat_q': not shared.args.no_mul_mat_q,
            'numa': shared.args.numa,
            'n_gpu_layers': shared.args.n_gpu_layers,
            'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base),
            'tensor_split': tensor_split_list,
            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
            'offload_kqv': not shared.args.no_offload_kqv,
            'split_mode': 1 if not shared.args.row_split else 2,
            'flash_attn': shared.args.flash_attn
        }

        result.model = Llama(**params)
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        """Tokenize ``string``; a ``str`` is encoded to bytes first, anything else is passed as-is."""
        if isinstance(string, str):
            string = string.encode()

        return self.model.tokenize(string)

    def decode(self, ids, **kwargs):
        """Detokenize ``ids`` and return the result as text.

        Extra keyword arguments are accepted and ignored. Bytes that are not
        valid UTF-8 (for example when a token sequence was cut in the middle
        of a multi-byte character) are replaced with U+FFFD instead of
        raising ``UnicodeDecodeError``.
        """
        return self.model.detokenize(ids).decode('utf-8', errors='replace')

    def get_logits(self, tokens):
        """Reset the model, evaluate ``tokens`` and return its scores.

        Returns:
            A float32 ``torch`` tensor built from the model's ``_scores`` with
            an extra leading axis added.
        """
        self.model.reset()
        self.model.eval(tokens)
        logits = self.model._scores
        logits = np.expand_dims(logits, 0)  # batch dim is expected
        return torch.tensor(logits, dtype=torch.float32)

    def load_grammar(self, string):
        """Compile ``string`` as the active grammar, skipping work if it is unchanged.

        A blank string clears the grammar.
        """
        if string == self.grammar_string:
            return

        self.grammar_string = string
        if string.strip() != '':
            self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string)
        else:
            self.grammar = None

    def generate(self, prompt, state, callback=None):
        """Generate a completion for ``prompt`` and return the full text.

        Args:
            prompt: Prompt as ``str`` or as an object with ``.decode()``.
            state: Mapping of generation settings (sampling parameters,
                ``grammar_string``, ``ban_eos_token``, ``custom_token_bans``, ...).
            callback: Optional callable invoked with each streamed text chunk.

        Returns:
            The concatenation of all streamed chunks. Streaming stops early
            when ``shared.stop_everything`` becomes true.
        """
        LogitsProcessorList = llama_cpp_lib().LogitsProcessorList

        if not isinstance(prompt, str):
            prompt = prompt.decode()

        # Handle truncation: keep only the trailing tokens that fit, then turn
        # them back into text for create_completion().
        prompt = self.encode(prompt)
        prompt = prompt[-get_max_prompt_length(state):]
        prompt = self.decode(prompt)

        self.load_grammar(state['grammar_string'])

        logit_processors = LogitsProcessorList()
        if state['ban_eos_token']:
            logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))

        if state['custom_token_bans']:
            # Skip empty items so input such as "1, 2," or "1,,2" does not
            # crash on int('').
            to_ban = [int(x) for x in state['custom_token_bans'].split(',') if x.strip()]
            if to_ban:
                logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))

        completion_chunks = self.model.create_completion(
            prompt=prompt,
            max_tokens=state['max_new_tokens'],
            temperature=state['temperature'],
            top_p=state['top_p'],
            min_p=state['min_p'],
            typical_p=state['typical_p'],
            frequency_penalty=state['frequency_penalty'],
            presence_penalty=state['presence_penalty'],
            repeat_penalty=state['repetition_penalty'],
            top_k=state['top_k'],
            stream=True,
            seed=int(state['seed']) if state['seed'] != -1 else None,
            tfs_z=state['tfs'],
            mirostat_mode=int(state['mirostat_mode']),
            mirostat_tau=state['mirostat_tau'],
            mirostat_eta=state['mirostat_eta'],
            logits_processor=logit_processors,
            grammar=self.grammar
        )

        pieces = []
        for completion_chunk in completion_chunks:
            if shared.stop_everything:
                break

            text = completion_chunk['choices'][0]['text']
            pieces.append(text)
            if callback:
                callback(text)

        return "".join(pieces)

    def generate_with_streaming(self, *args, **kwargs):
        """Run ``generate`` through ``Iteratorize`` and yield the cumulative reply after each chunk."""
        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply