fix returned dtypes for LLaVA (#1547)

This commit is contained in:
Wojtab 2023-04-26 02:25:34 +02:00 committed by GitHub
parent 9b272bc8e5
commit 65beb51b0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -245,7 +245,9 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds):
prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state) prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s') print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device) return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
def ui(): def ui():