2023-06-16 19:35:38 -04:00
|
|
|
from pathlib import Path
|
|
|
|
|
2023-09-16 08:42:38 -04:00
|
|
|
import torch
|
2023-08-06 16:22:48 -04:00
|
|
|
import torch.nn.functional as F
|
2023-06-29 14:03:16 -04:00
|
|
|
from torch import version as torch_version
|
|
|
|
|
2023-09-11 10:57:38 -04:00
|
|
|
from modules import shared
|
2023-06-16 19:35:38 -04:00
|
|
|
from modules.logging_colors import logger
|
2023-08-06 16:22:48 -04:00
|
|
|
from modules.models import clear_torch_cache
|
2023-07-07 12:09:23 -04:00
|
|
|
from modules.text_generation import get_max_prompt_length
|
2023-06-16 19:49:36 -04:00
|
|
|
|
2023-06-24 19:24:17 -04:00
|
|
|
# Import ExLlama (v1) from the installed package if available; otherwise fall back to a
# clone of the exllama repository under repositories/exllama.
try:
    from exllama.generator import ExLlamaGenerator
    from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
    from exllama.tokenizer import ExLlamaTokenizer
except Exception:
    # Deliberately broad so the fallback is attempted for any failure while importing the
    # package. Unlike a bare `except:`, this does not swallow KeyboardInterrupt/SystemExit.
    logger.warning('exllama module failed to import. Will attempt to import from repositories/.')
    try:
        from modules.relative_imports import RelativeImport

        # RelativeImport makes the modules in repositories/exllama importable by their
        # bare names (generator, model, tokenizer) inside this block.
        with RelativeImport("repositories/exllama"):
            from generator import ExLlamaGenerator
            from model import ExLlama, ExLlamaCache, ExLlamaConfig
            from tokenizer import ExLlamaTokenizer
    except Exception:
        logger.error(
            "Could not find repositories/exllama. Please ensure that exllama"
            " (https://github.com/turboderp/exllama) is cloned inside repositories/ and is up to date."
        )
        # Re-raise: the loader cannot work without exllama.
        raise
|
2023-06-16 19:35:38 -04:00
|
|
|
|
|
|
|
|
|
|
|
class ExllamaModel:
    """Adapter that exposes an ExLlama (v1) model through the web UI's model interface.

    ``from_pretrained`` returns the same instance as both the "model" and the
    "tokenizer", so this class implements ``encode``/``decode`` as well as the
    generation entry points (``generate``, ``generate_with_streaming``,
    ``get_logits``).
    """

    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        """Load the model found in ``shared.args.model_dir / path_to_model``.

        The folder must contain ``config.json``, ``tokenizer.model`` and a checkpoint
        (see ``_find_checkpoint``). Sequence length, positional-embedding compression,
        GPU split and RoPE options are taken from ``shared.args``.

        Returns:
            A ``(model, tokenizer)`` pair; both elements are the same ``ExllamaModel``
            instance.

        Raises:
            FileNotFoundError: if no checkpoint file is present in the folder.
        """
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        tokenizer_model_path = path_to_model / "tokenizer.model"
        model_config_path = path_to_model / "config.json"

        model_path = cls._find_checkpoint(path_to_model)

        config = ExLlamaConfig(str(model_config_path))
        config.model_path = str(model_path)
        config.max_seq_len = shared.args.max_seq_len
        config.compress_pos_emb = shared.args.compress_pos_emb
        if shared.args.gpu_split:
            config.set_auto_map(shared.args.gpu_split)
            config.gpu_peer_fix = True

        # An explicit rope_freq_base wins; otherwise an alpha_value > 1 is turned into a
        # rotary embedding base by ExLlamaConfig.calculate_rotary_embedding_base().
        if shared.args.alpha_value > 1 and shared.args.rope_freq_base == 0:
            config.alpha_value = shared.args.alpha_value
            config.calculate_rotary_embedding_base()
        elif shared.args.rope_freq_base > 0:
            config.rotary_embedding_base = shared.args.rope_freq_base

        # On ROCm/HIP builds of PyTorch, switch off the half2 variants of these kernels.
        if torch_version.hip:
            config.rmsnorm_no_half2 = True
            config.rope_no_half2 = True
            config.matmul_no_half2 = True
            config.silu_no_half2 = True

        model = ExLlama(config)
        tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
        cache = ExLlamaCache(model)
        generator = ExLlamaGenerator(model, tokenizer, cache)

        result = cls()
        result.config = config
        result.model = model
        result.cache = cache
        result.tokenizer = tokenizer
        result.generator = generator
        return result, result

    @staticmethod
    def _find_checkpoint(path_to_model):
        """Return the checkpoint file in ``path_to_model``.

        Extensions are tried in the order .safetensors, .pt, .bin. Within the first
        extension that matches, the last file in sorted order is chosen, so the pick is
        deterministic rather than dependent on directory listing order.

        Raises:
            FileNotFoundError: if no file with any of those extensions exists.
        """
        for ext in ('.safetensors', '.pt', '.bin'):
            found = sorted(path_to_model.glob(f"*{ext}"))
            if found:
                if len(found) > 1:
                    logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')

                return found[-1]

        raise FileNotFoundError(f"No .safetensors, .pt or .bin checkpoint was found in {path_to_model}")

    def encode(self, string, **kwargs):
        """Tokenize ``string`` with a BOS token prepended.

        The model's ``max_seq_len`` is passed to the ExLlama tokenizer. Extra keyword
        arguments are accepted for interface compatibility and ignored.
        """
        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)

    def decode(self, ids, **kwargs):
        """Convert token ids back to text.

        ``ids`` may be an int, a list of ints, or a tensor. The input is turned into a
        single-row 2-D batch before calling the ExLlama tokenizer. Its ``decode``
        returns a list with one string per row for 2-D input, and the first row's text
        is returned. Extra keyword arguments are ignored.
        """
        if isinstance(ids, int):
            ids = [ids]

        if isinstance(ids, list):
            ids = torch.tensor([ids])
        elif isinstance(ids, torch.Tensor) and (ids.dim() < 2 or ids.numel() == 1):
            # 0-d and 1-d tensors must become a (1, n) batch: for 1-d input the ExLlama
            # tokenizer returns a plain str, and [0] would keep only its first character.
            ids = ids.reshape(1, -1)

        return self.tokenizer.decode(ids)[0]

    def get_logits(self, token_ids, **kwargs):
        """Return the logits for the last position of ``token_ids`` as a float32 CPU tensor.

        The cache is reset. All tokens but the last are run in preprocess-only mode to
        fill the cache, and then only the last token is forwarded to produce logits.
        ``kwargs`` go to that final forward call. ``token_ids`` is indexed as
        ``[:, ...]``, so it must be at least 2-D (presumably ``(batch, seq)``).
        """
        self.cache.current_seq_len = 0
        if token_ids.shape[-1] > 1:
            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)

        return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()

    def generate_with_streaming(self, prompt, state):
        """Generate a reply to ``prompt``, yielding the full text generated so far after each token.

        ``state`` holds the sampling parameters. A ``guidance_scale`` other than 1 enables
        classifier-free guidance (CFG) against ``state['negative_prompt']``.
        """
        use_cfg = state['guidance_scale'] != 1

        # CFG evaluates the prompt and the negative prompt together, so it needs a cache
        # with batch size 2; plain generation uses batch size 1.
        self._ensure_cache_batch_size(2 if use_cfg else 1)
        self._apply_sampling_settings(state)

        if use_cfg:
            yield from self._stream_cfg(prompt, state)
        else:
            yield from self._stream_plain(prompt, state)

    def _ensure_cache_batch_size(self, batch_size):
        """Rebuild the cache (and the generator using it) if its batch size differs from ``batch_size``."""
        if self.cache.batch_size == batch_size:
            return

        # The generator was constructed with the old cache and keeps a reference to it.
        # Drop both before clearing the torch cache so the old cache's memory can be released.
        del self.generator
        del self.cache
        clear_torch_cache()
        self.cache = ExLlamaCache(self.model, batch_size=batch_size)
        self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)

    def _apply_sampling_settings(self, state):
        """Copy sampling parameters and token bans from ``state`` onto the generator."""
        settings = self.generator.settings
        settings.temperature = state['temperature']
        settings.top_p = state['top_p']
        settings.top_k = state['top_k']
        settings.typical = state['typical_p']
        settings.token_repetition_penalty_max = state['repetition_penalty']
        # A non-positive range is passed to ExLlama as -1.
        settings.token_repetition_penalty_sustain = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']

        # disallow_tokens() replaces any previous list, so the EOS ban and the custom bans
        # must be combined into a single call, or one would overwrite the other.
        self.generator.disallow_tokens(self._collect_disallowed_tokens(state, self.tokenizer.eos_token_id))

    @staticmethod
    def _collect_disallowed_tokens(state, eos_token_id):
        """Return the list of token ids to ban for this request, or ``None`` if there are none.

        ``state['custom_token_bans']`` is a comma-separated string of integer ids. Blank
        entries (e.g. from a trailing comma) are ignored.
        """
        disallowed = []
        if state['ban_eos_token']:
            disallowed.append(eos_token_id)

        if state['custom_token_bans']:
            disallowed.extend(int(x) for x in state['custom_token_bans'].split(',') if x.strip())

        return disallowed or None

    @staticmethod
    def _max_new_tokens(state, prompt_length):
        """Return the token budget: the remaining context if ``auto_max_new_tokens`` is set, else ``max_new_tokens``."""
        if state['auto_max_new_tokens']:
            return state['truncation_length'] - prompt_length

        return state['max_new_tokens']

    def _starts_with_space(self, token):
        """Return True if the SentencePiece piece for ``token`` begins with the word-boundary marker."""
        return self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁')

    def _decode_new_text(self, initial_len, has_leading_space):
        """Decode the first sequence row from ``initial_len`` onward, restoring a leading space if needed."""
        text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
        return ' ' + text if has_leading_space else text

    def _stream_plain(self, prompt, state):
        """Stream a reply without CFG, reusing the cached prefix where possible."""
        self.generator.end_beam_search()

        ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
        if state['add_bos_token']:
            ids = torch.cat(
                [torch.tensor([[self.tokenizer.bos_token_id]]).to(ids.device),
                 ids], dim=1
            ).to(torch.int64)

        ids = ids[:, -get_max_prompt_length(state):]
        max_new_tokens = self._max_new_tokens(state, ids.shape[-1])

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        for i in range(max_new_tokens):
            # gen_single_token() samples a token and appends it to generator.sequence.
            token = self.generator.gen_single_token()
            if i == 0 and self._starts_with_space(token):
                has_leading_space = True

            yield self._decode_new_text(initial_len, has_leading_space)
            if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                break

    # Based on https://github.com/turboderp/exllama/blob/master/example_cfg.py
    def _stream_cfg(self, prompt, state):
        """Stream a reply with classifier-free guidance against ``state['negative_prompt']``."""
        alpha = state['guidance_scale']
        prompts = [prompt, state['negative_prompt'] or '']

        ids, mask = self.tokenizer.encode(
            prompts,
            return_mask=True,
            max_seq_len=self.model.config.max_seq_len,
            add_bos=state['add_bos_token']
        )
        max_new_tokens = self._max_new_tokens(state, ids.shape[-1])

        self.generator.gen_begin(ids, mask=mask)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        for i in range(max_new_tokens):
            logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask)
            self.generator.apply_rep_penalty(logits)

            # Blend the log-probabilities of row 0 (prompt) and row 1 (negative prompt).
            logits = F.log_softmax(logits, dim=-1)
            logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1]

            token, _ = self.generator.sample_current(logits_mixed)
            if i == 0 and self._starts_with_space(token):
                has_leading_space = True

            is_eos = token.item() == self.tokenizer.eos_token_id
            if not is_eos:
                # Append the token to both batch rows *before* decoding, so that each
                # yield includes the token just sampled (including the final one).
                self.generator.gen_accept_token(token.repeat(2, 1))

            yield self._decode_new_text(initial_len, has_leading_space)
            if is_eos or shared.stop_everything:
                break

    def generate(self, prompt, state):
        """Run ``generate_with_streaming`` to completion and return the final text ('' if nothing was yielded)."""
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output
|