From 65beb51b0b9d3ca22c947b587eed06c7f216338a Mon Sep 17 00:00:00 2001 From: Wojtab Date: Wed, 26 Apr 2023 02:25:34 +0200 Subject: [PATCH] fix returned dtypes for LLaVA (#1547) --- extensions/llava/script.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/extensions/llava/script.py b/extensions/llava/script.py index d48e35fa..ba951f20 100644 --- a/extensions/llava/script.py +++ b/extensions/llava/script.py @@ -245,7 +245,9 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds): prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state) print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s') - return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device) + return (prompt, + input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64), + input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype)) def ui():