Fix exllama_hf gibberish above 2048 context, and make it work with >5000 context. (#2913)

Panchovix 2023-06-28 11:36:07 -04:00 committed by GitHub
parent 63770c0643
commit 37a16d23a7


@@ -54,7 +54,15 @@ class ExllamaHF(PreTrainedModel):
         cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
         if cache is None:
             cache = ExLlamaCache(self.ex_model)
-            self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
+            nseq = seq[:-1]
+            for seqs in [nseq[i : i + 2048] for i in range(0, len(nseq), 2048)]:
+                self.ex_model.forward(
+                    torch.tensor([seqs], dtype=torch.long),
+                    cache,
+                    preprocess_only=True,
+                    lora=self.lora,
+                )
         logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
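
For context: the patch replaces a single forward pass over the whole prompt with a loop that prefills the ExLlama cache in 2048-token chunks, so no single call exceeds the model's native sequence length; only the final token is then run through the logits pass. A minimal sketch of just the chunking logic, with a stand-in fake_forward and a CHUNK constant (both illustrative, not part of the commit):

CHUNK = 2048

def prefill(forward, cache, seq):
    # Feed seq[:-1] to the model in CHUNK-sized slices, mirroring the
    # loop added in the patch; the last token is held back for the
    # separate logits call.
    nseq = seq[:-1]
    for start in range(0, len(nseq), CHUNK):
        forward(nseq[start:start + CHUNK], cache, preprocess_only=True)

def fake_forward(tokens, cache, preprocess_only=False):
    # Hypothetical stand-in for self.ex_model.forward(); it only
    # records how many tokens each call receives.
    cache.append(len(tokens))

cache = []
prefill(fake_forward, cache, list(range(5000)))
print(cache)  # [2048, 2048, 903] -- a 5000-token prompt prefills 4999 tokens

With the pre-patch single call, a >2048-token prompt would hit the model in one forward pass beyond its native window, which matches the gibberish the commit title describes; the chunked loop keeps every call within bounds while still filling the same cache.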