diff --git a/extensions/llava/script.py b/extensions/llava/script.py index d48e35fa..ba951f20 100644 --- a/extensions/llava/script.py +++ b/extensions/llava/script.py @@ -245,7 +245,9 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds): prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state) print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s') - return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device) + return (prompt, + input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64), + input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype)) def ui():