diff --git a/extensions/multimodal/README.md b/extensions/multimodal/README.md index 10bbc7f5..50681034 100644 --- a/extensions/multimodal/README.md +++ b/extensions/multimodal/README.md @@ -11,10 +11,10 @@ https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples: ``` -python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat -python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat -python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat -python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat +python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b +python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b +python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b +python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b ``` There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`: diff --git a/extensions/multimodal/pipelines/llava/llava.py b/extensions/multimodal/pipelines/llava/llava.py index eca2be50..306ab227 100644 --- a/extensions/multimodal/pipelines/llava/llava.py +++ b/extensions/multimodal/pipelines/llava/llava.py @@ -56,10 +56,13 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline): @staticmethod def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor: - if hasattr(shared.model.model, 'embed_tokens'): - func = shared.model.model.embed_tokens + for attr in ['', 'model', 'model.model', 'model.model.model']: + tmp = getattr(shared.model, attr, None) if attr != '' else shared.model + if tmp is not None and hasattr(tmp, 'embed_tokens'): + func = tmp.embed_tokens + break else: - func = shared.model.model.model.embed_tokens # AutoGPTQ case + raise ValueError('The embed_tokens method has not been found for this loader.') return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)