diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index e12a0717..952d7172 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -108,10 +108,10 @@ class Exllamav2HF(PreTrainedModel): if len(seq_tensor) > 1: self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras) - logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device) + logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float() else: ex_cache.current_seq_len = 0 - logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras) + logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float() if is_negative: self.past_seq_negative = seq_tensor