ExLlamav2_HF: Convert logits to FP32 (#4310)

commit ae8cd449ae
parent c0ffb77fd8
Author: turboderp
Date:   2023-10-19 04:16:05 +02:00 (committed by GitHub)


@@ -108,10 +108,10 @@ class Exllamav2HF(PreTrainedModel):
             if len(seq_tensor) > 1:
                 self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
 
-            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
+            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
         else:
             ex_cache.current_seq_len = 0
-            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
 
         if is_negative:
             self.past_seq_negative = seq_tensor
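
For context, a minimal standalone sketch (not part of the commit) of the cast the diff adds, assuming PyTorch and logits produced in half precision, as ExLlamaV2 typically does; the tensor shape and vocabulary size below are illustrative only:

    import torch

    # Stand-in for the model's raw output: ExLlamaV2-style backends
    # typically return logits in the model's working dtype (often FP16).
    logits_fp16 = torch.randn(1, 1, 32000, dtype=torch.float16)

    # The pattern applied by this commit: cast to FP32 before handing
    # the logits to downstream sampling / logits-processor code.
    logits_fp32 = logits_fp16.float()

    probs = torch.softmax(logits_fp32, dim=-1)          # softmax in full precision
    next_token = torch.multinomial(probs.view(-1), 1)   # sample one token id
    print(logits_fp32.dtype, next_token.item())

Doing the softmax and sampling in FP32 avoids the reduced precision of FP16 in the post-processing path; the conversion cost is negligible since only the final-position logits are cast.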