2023-05-15 19:19:55 -04:00
|
|
|
import re
|
2023-06-19 20:31:19 -04:00
|
|
|
from functools import partial
|
2023-05-15 19:19:55 -04:00
|
|
|
|
2023-07-19 22:31:19 -04:00
|
|
|
import torch
|
2023-03-19 02:42:10 -04:00
|
|
|
|
2023-03-31 20:18:05 -04:00
|
|
|
from modules import shared
|
2023-03-31 13:27:01 -04:00
|
|
|
from modules.callbacks import Iteratorize
|
2023-05-21 21:42:34 -04:00
|
|
|
from modules.logging_colors import logger
|
2023-03-31 13:27:01 -04:00
|
|
|
|
2023-07-24 10:25:36 -04:00
|
|
|
# Pick which llama-cpp-python build provides Llama / LlamaCache /
# LogitsProcessorList. When torch reports a usable CUDA device and is not a
# ROCm/HIP build (torch.version.hip is only set on HIP builds), try the
# CUDA-specific package first; otherwise use the regular package.
if torch.cuda.is_available() and not torch.version.hip:
    try:
        from llama_cpp_cuda import Llama, LlamaCache, LogitsProcessorList
    except Exception:
        # Fall back to the regular package whenever the CUDA build cannot be
        # imported (missing, or failing while loading). Catching Exception
        # instead of using a bare `except:` keeps that fallback without
        # swallowing KeyboardInterrupt / SystemExit.
        from llama_cpp import Llama, LlamaCache, LogitsProcessorList
else:
    from llama_cpp import Llama, LlamaCache, LogitsProcessorList
|
|
|
|
|
2023-03-19 02:42:10 -04:00
|
|
|
|
2023-06-19 20:31:19 -04:00
|
|
|
def ban_eos_logits_processor(eos_token, input_ids, logits):
    """Logits processor that makes the end-of-sequence token unselectable.

    Intended to be bound to a concrete EOS id with ``functools.partial`` and
    passed to llama-cpp-python inside a ``LogitsProcessorList``.

    Args:
        eos_token: index of the EOS token within ``logits``.
        input_ids: tokens seen so far; unused here, accepted only because the
            caller passes it.
        logits: mutable sequence of scores; modified in place.

    Returns:
        The very same ``logits`` object, with its EOS entry set to -inf.
    """
    banned_score = float('-inf')
    logits[eos_token] = banned_score
    return logits
|
|
|
|
|
|
|
|
|
2023-03-19 02:42:10 -04:00
|
|
|
class LlamaCppModel:
    """Adapter around a llama-cpp-python ``Llama`` instance.

    The same object serves as both model and tokenizer: ``from_pretrained``
    returns it twice, and it offers ``encode``/``decode`` alongside
    ``generate``/``generate_with_streaming``.
    """

    def __init__(self):
        self.initialized = False

    def __del__(self):
        # `model` is only assigned by `from_pretrained` after `Llama(...)`
        # succeeds. If that never happened there is nothing to release, and
        # reading `self.model` would raise AttributeError during GC.
        model = getattr(self, 'model', None)
        if model is not None:
            model.__del__()

    @staticmethod
    def _parse_cache_capacity(value):
        """Convert a cache-capacity setting into a number of bytes.

        Args:
            value: ``None`` (no cache), a plain integer or integer string
                (bytes), or a string containing a ``GiB`` / ``MiB`` suffix,
                e.g. ``'2GiB'`` or ``'1.5GiB'``.

        Returns:
            The capacity in bytes as an ``int`` (0 for ``None``). The binary
            suffixes use powers of 1024 (1 GiB = 1024**3 bytes).

        Raises:
            ValueError: if the numeric part cannot be parsed.
        """
        if value is None:
            return 0
        if isinstance(value, str):
            # Checked in this order, matching the original GiB-then-MiB test.
            for suffix, multiplier in (('GiB', 1024 ** 3), ('MiB', 1024 ** 2)):
                if suffix in value:
                    # Strip every letter to leave the number; float() allows
                    # fractional sizes such as "1.5GiB".
                    number = float(re.sub(r'[a-zA-Z]', '', value))
                    return int(number * multiplier)
        return int(value)

    @classmethod
    def from_pretrained(cls, path):
        """Load a llama.cpp model file and wrap it.

        All loader options are read from ``shared.args``.

        Args:
            path: location of the model file; converted with ``str()``.

        Returns:
            ``(model, tokenizer)``: both are the same instance, because the
            underlying ``Llama`` object both tokenizes and generates.
        """
        result = cls()

        cache_capacity = cls._parse_cache_capacity(shared.args.cache_capacity)
        logger.info(f"Cache capacity is {cache_capacity} bytes")

        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
            'seed': int(shared.args.llama_cpp_seed),
            # Falsy values (0 / None) are passed as None, presumably letting
            # llama-cpp-python choose its own default.
            'n_threads': shared.args.threads or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'low_vram': shared.args.low_vram,
            'n_gpu_layers': shared.args.n_gpu_layers,
            # alpha_value raises the RoPE frequency base as
            # 10000 * alpha ** (64/63); compress_pos_emb becomes a linear
            # frequency scale of 1 / compress_pos_emb.
            'rope_freq_base': 10000 * shared.args.alpha_value ** (64/63.),
            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
            'n_gqa': shared.args.n_gqa or None,
            'rms_norm_eps': shared.args.rms_norm_eps or None,
        }

        result.model = Llama(**params)
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        """Tokenize ``string`` with the wrapped model.

        ``str`` input is first encoded with ``str.encode()`` (UTF-8); any
        other input is passed to ``tokenize`` unchanged. Returns whatever
        ``self.model.tokenize`` returns (presumably a list of token ids).
        """
        if isinstance(string, str):
            string = string.encode()

        return self.model.tokenize(string)

    def decode(self, tokens):
        """Return ``self.model.detokenize(tokens)``.

        NOTE(review): llama-cpp-python's detokenize presumably yields bytes,
        not str; confirm what callers expect.
        """
        return self.model.detokenize(tokens)

    def generate(self, prompt, state, callback=None):
        """Run a streamed completion and return the full generated text.

        Args:
            prompt: ``str``, or an object with ``.decode()`` (e.g. bytes),
                which is decoded first.
            state: generation settings dict. Reads ``max_new_tokens``,
                ``temperature``, ``top_p``, ``top_k``,
                ``repetition_penalty``, ``tfs``, ``mirostat_mode``,
                ``mirostat_tau``, ``mirostat_eta`` and ``ban_eos_token``.
            callback: optional callable invoked with each text piece as it
                arrives.

        Returns:
            The concatenation of all streamed text pieces.
        """
        prompt = prompt if isinstance(prompt, str) else prompt.decode()

        logits_processor = None
        if state['ban_eos_token']:
            logits_processor = LogitsProcessorList([
                partial(ban_eos_logits_processor, self.model.token_eos()),
            ])

        completion_chunks = self.model.create_completion(
            prompt=prompt,
            max_tokens=state['max_new_tokens'],
            temperature=state['temperature'],
            top_p=state['top_p'],
            top_k=state['top_k'],
            repeat_penalty=state['repetition_penalty'],
            tfs_z=state['tfs'],
            mirostat_mode=int(state['mirostat_mode']),
            mirostat_tau=state['mirostat_tau'],
            mirostat_eta=state['mirostat_eta'],
            stream=True,
            logits_processor=logits_processor,
        )

        # Collect pieces in a list and join once, instead of repeated string
        # concatenation inside the loop.
        pieces = []
        for completion_chunk in completion_chunks:
            text = completion_chunk['choices'][0]['text']
            pieces.append(text)
            if callback:
                callback(text)

        return ''.join(pieces)

    def generate_with_streaming(self, *args, **kwargs):
        """Yield the reply accumulated so far after each generated piece.

        Arguments are forwarded to ``generate``. ``Iteratorize`` (a project
        helper) presumably converts ``generate``'s callback into an iterator
        of text pieces.
        """
        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply