text-generation-webui/modules/llamacpp_model.py

102 lines
3.3 KiB
Python
Raw Normal View History

'''
Based on
https://github.com/abetlen/llama-cpp-python

Documentation:
https://abetlen.github.io/llama-cpp-python/
'''
import re
from functools import partial
from llama_cpp import Llama, LlamaCache, LogitsProcessorList
2023-03-19 02:42:10 -04:00
2023-03-31 20:18:05 -04:00
from modules import shared
2023-03-31 13:27:01 -04:00
from modules.callbacks import Iteratorize
from modules.logging_colors import logger
2023-03-31 13:27:01 -04:00
2023-03-19 02:42:10 -04:00
def ban_eos_logits_processor(eos_token, input_ids, logits):
    """Force the logit stored at index ``eos_token`` to negative infinity.

    ``logits`` is modified in place and the same object is returned.
    ``input_ids`` is accepted but not used by this function. Elsewhere in
    this file the function is wrapped with ``functools.partial`` so that
    ``eos_token`` is pre-bound before it is handed to a
    ``LogitsProcessorList``.
    """
    negative_infinity = float('-inf')
    logits[eos_token] = negative_infinity
    return logits
2023-03-19 02:42:10 -04:00
class LlamaCppModel:
    """Wrapper around a llama-cpp-python ``Llama`` object.

    Instances are normally created through ``from_pretrained``, which attaches
    the underlying ``Llama`` object as ``self.model``. The same wrapper is used
    as both the model and the tokenizer.
    """

    def __init__(self):
        self.initialized = False

    def __del__(self):
        # ``model`` is only attached by ``from_pretrained``. An instance that
        # never got that far (bare construction, or a load that failed before
        # ``Llama(...)`` returned) must not raise from its finalizer.
        model = getattr(self, 'model', None)
        if model is not None:
            model.__del__()

    @staticmethod
    def _parse_cache_capacity(value):
        """Convert a cache-capacity setting into a number of bytes.

        Args:
            value: ``None``, a plain integer string (taken as bytes), or an
                integer followed by a unit containing ``GiB`` or ``MiB``.

        Returns:
            The capacity in bytes as an ``int``; ``0`` when ``value`` is
            ``None``.

        Raises:
            ValueError: if the numeric part cannot be parsed by ``int``.
        """
        if value is None:
            return 0

        # NOTE: both suffixes are scaled by powers of 1000 (not 1024), so
        # '2GiB' yields 2_000_000_000 bytes. Kept as-is for compatibility.
        if 'GiB' in value:
            return int(re.sub('[a-zA-Z]', '', value)) * 1000 * 1000 * 1000
        if 'MiB' in value:
            return int(re.sub('[a-zA-Z]', '', value)) * 1000 * 1000
        return int(value)

    @classmethod
    def from_pretrained(cls, path):
        """Load the model file at ``path`` using settings from ``shared.args``.

        Args:
            path: location of the model file; passed to ``Llama`` as
                ``str(path)``.

        Returns:
            A ``(model, tokenizer)`` tuple in which both elements are the same
            ``LlamaCppModel`` instance.
        """
        result = cls()

        cache_capacity = cls._parse_cache_capacity(shared.args.cache_capacity)
        logger.info(f"Cache capacity is {cache_capacity} bytes")

        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
            'seed': int(shared.args.llama_cpp_seed),
            'n_threads': shared.args.threads or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'n_gpu_layers': shared.args.n_gpu_layers
        }

        result.model = Llama(**params)
        # A cache is only installed when a positive capacity was requested.
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        """Tokenize ``string`` with the underlying model.

        ``str`` input is encoded to bytes with ``str.encode()`` first; any
        other input is passed to ``self.model.tokenize`` unchanged.
        """
        if isinstance(string, str):
            string = string.encode()

        return self.model.tokenize(string)

    def generate(self, prompt, state, callback=None):
        """Run a streamed completion and return the concatenated text.

        Args:
            prompt: the prompt as ``str``, or an object with ``.decode()``
                (e.g. ``bytes``) which is decoded first.
            state: mapping providing ``max_new_tokens``, ``temperature``,
                ``top_p``, ``top_k``, ``repetition_penalty``, ``tfs``,
                ``mirostat_mode``, ``mirostat_tau``, ``mirostat_eta`` and
                ``ban_eos_token``.
            callback: optional callable invoked with the text of every chunk
                as it arrives.

        Returns:
            The text of all chunks joined together, as a ``str``.
        """
        prompt = prompt if isinstance(prompt, str) else prompt.decode()

        # Only build the EOS-banning processor when the caller asked for it.
        logits_processor = None
        if state['ban_eos_token']:
            logits_processor = LogitsProcessorList([
                partial(ban_eos_logits_processor, self.model.token_eos()),
            ])

        completion_chunks = self.model.create_completion(
            prompt=prompt,
            max_tokens=state['max_new_tokens'],
            temperature=state['temperature'],
            top_p=state['top_p'],
            top_k=state['top_k'],
            repeat_penalty=state['repetition_penalty'],
            tfs_z=state['tfs'],
            mirostat_mode=int(state['mirostat_mode']),
            mirostat_tau=state['mirostat_tau'],
            mirostat_eta=state['mirostat_eta'],
            stream=True,
            logits_processor=logits_processor,
        )

        pieces = []
        for completion_chunk in completion_chunks:
            text = completion_chunk['choices'][0]['text']
            pieces.append(text)
            if callback:
                callback(text)

        return ''.join(pieces)

    def generate_with_streaming(self, *args, **kwargs):
        """Yield the cumulative reply each time a new piece of text arrives.

        ``args`` and ``kwargs`` are forwarded to ``generate`` through
        ``Iteratorize`` (presumably adapting the callback-driven ``generate``
        into an iterator of text pieces; see ``modules.callbacks``).
        """
        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply