xref: /aosp_15_r20/external/executorch/examples/models/llava/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# An ExecuTorch friendly implementation of Llava-1.5.
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport re
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, Optional, Tuple
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport requests
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.models.llama.source_transformation.sdpa import (
18*523fa7a6SAndroid Build Coastguard Worker    replace_sdpa_with_custom_op,
19*523fa7a6SAndroid Build Coastguard Worker)
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.models.llava.image_util import prepare_image
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.models.model_base import EagerModelBase
22*523fa7a6SAndroid Build Coastguard Workerfrom PIL import Image
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import Dim
25*523fa7a6SAndroid Build Coastguard Workerfrom torchvision.transforms.v2 import functional as F
26*523fa7a6SAndroid Build Coastguard Worker
27*523fa7a6SAndroid Build Coastguard Workerfrom transformers import (
28*523fa7a6SAndroid Build Coastguard Worker    AutoProcessor,
29*523fa7a6SAndroid Build Coastguard Worker    CLIPImageProcessor,
30*523fa7a6SAndroid Build Coastguard Worker    LlamaForCausalLM,
31*523fa7a6SAndroid Build Coastguard Worker    LlavaForConditionalGeneration,
32*523fa7a6SAndroid Build Coastguard Worker)
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Worker
class Llava(torch.nn.Module):
    """ExecuTorch-friendly wrapper around HF ``LlavaForConditionalGeneration``.

    Keeps the HF vision tower, token embedding, and multi-modal projector,
    but replaces the HF text decoder with the local Llama ``Transformer``
    so the text path can be exported with a KV cache (optionally using the
    custom sdpa_with_kv_cache op).
    """

    def __init__(
        self,
        llava_model: LlavaForConditionalGeneration,
        image_processor: CLIPImageProcessor,
        use_sdpa_with_kv_cache_op: bool = True,
        max_seq_len: int = 768,
    ):
        """Build the export-friendly text model and load HF weights into it.

        Args:
            llava_model: pretrained HF Llava model supplying config + weights.
            image_processor: CLIP image processor (crop size, mean/std, rescale).
            use_sdpa_with_kv_cache_op: swap SDPA for the custom fused op.
            max_seq_len: maximum sequence length for the text model's KV cache.
        """
        super().__init__()
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.model_ = llava_model
        self.image_processor = image_processor
        # Which vision-tower hidden state to use, and whether to drop the
        # CLS token ("default") or keep every position ("full").
        self.vision_feature_layer = self.model_.config.vision_feature_layer
        self.vision_feature_select_strategy = (
            self.model_.config.vision_feature_select_strategy
        )
        # Text-decoder args mirror the HF text_config where it matters
        # (vocab / hidden dim); the rest configure export behavior.
        self.text_model_args = ModelArgs(
            use_kv_cache=True,
            vocab_size=self.model_.config.text_config.vocab_size,
            hidden_dim=self.model_.config.text_config.intermediate_size,
            max_batch_size=1,  # doesn't work with default batch size 32
            ffn_dim_multiplier=1,  # TODO: a hack to make rotary embedding happy
            enable_dynamic_shape=True,  # allow parallel prefill
            use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op,  # use sdpa_with_kv_cache op
            use_hf_rope=True,
            max_seq_len=max_seq_len,
        )
        self.text_model = Transformer(self.text_model_args)
        # use custom op for SDPA.
        if use_sdpa_with_kv_cache_op:
            self.text_model = replace_sdpa_with_custom_op(self.text_model)
        # Load translated HF weights. strict=False because tok_embeddings is
        # deliberately not translated (embedding stays on the HF side, see
        # embed_tokens below); assign=True assigns tensors instead of copying.
        self.text_model.load_state_dict(
            state_dict=self._translate_state_dict_for_text_model(),
            strict=False,
            assign=True,
        )

    def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
        """Rename HF LlamaForCausalLM state-dict keys to the local Llama
        ``Transformer`` naming scheme (wq/wk/wv/wo, feed_forward.w1/2/3, ...).

        Keys that match no pattern are passed through unchanged.
        """
        state_dict = self.model_.language_model.state_dict()
        key_map = {
            # fmt: off
            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
            r"model.norm.": r"norm.",
            # r"model.embed_tokens.": r"tok_embeddings.", # load separately
            r"lm_head.": r"output.",
            # fmt: on
        }

        new_state_dict = {}

        def get_new_key(old_key: str) -> str:
            # First pattern that actually changes the key wins.
            for old_pattern, replacement in key_map.items():
                if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                    return new_key

            return old_key

        # Convert module keys from hf transformer to Llama transformer.
        for old_key in state_dict.keys():
            new_key = get_new_key(old_key)

            new_state_dict[new_key] = state_dict[old_key]

        return new_state_dict

    def _feature_select(self, image_outputs):
        """Pick the configured hidden-state layer from the vision tower output
        and apply the configured selection strategy.

        Raises:
            ValueError: if the strategy is neither "default" nor "full".
        """
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            # Drop position 0 (the CLS token) and keep the patch tokens.
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )
        return selected_image_feature

    def get_model(self):
        """Delegate to the wrapped HF model's get_model()."""
        return self.model_.get_model()

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the HF language model's embedding table
        (kept on the HF side; not loaded into the local text model)."""
        return self.model_.language_model.model.embed_tokens(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Run images through the vision tower, select features, and project
        them into the text embedding space."""
        images = images.to(dtype=self.model_.dtype)
        # NOTE(review): if a list were ever passed here, the .to() call above
        # would raise before this check — the list branch looks unreachable.
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.model_.vision_tower(
                    image.to(
                        device=self.model_.device, dtype=self.model_.dtype
                    ).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self._feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.model_.vision_tower(
                images.to(device=self.model_.device, dtype=self.model_.dtype),
                output_hidden_states=True,
            )
            image_features = self._feature_select(image_forward_outs).to(images.dtype)
        image_features = self.model_.multi_modal_projector(image_features)
        return image_features

    def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Pad, rescale, and normalize one CHW image for the vision tower.

        Assumes ``img`` is a (C, H, W) tensor no larger than the processor's
        crop size in either dimension (enforced via torch._check for export).
        Returns a (1, C, crop_h, crop_w) batch.
        """
        target_h = self.image_processor.crop_size["height"]
        target_w = self.image_processor.crop_size["width"]
        # pad the image with median rgb value, to make a square
        l_pad = (target_w - img.shape[2]) // 2
        t_pad = (target_h - img.shape[1]) // 2
        # ceil division
        r_pad = -((target_w - img.shape[2]) // -2)
        b_pad = -((target_h - img.shape[1]) // -2)

        # Export-time guards: padding must be non-negative, i.e. the input
        # image fits inside the crop size.
        torch._check(l_pad >= 0)
        torch._check(t_pad >= 0)
        torch._check(r_pad >= 0)
        torch._check(b_pad >= 0)

        # This is different from the original implementation, due to export limitations.
        # Zero-padding instead of mean-fill padding (see the "originally" note).
        resized = torch.nn.functional.pad(
            img,
            (l_pad, r_pad, t_pad, b_pad),
        )
        # originally:
        # resized = F.pad(
        #     img,
        #     padding=(l_pad, t_pad, r_pad, b_pad),
        #     fill=tuple(int(x * 255) for x in self.image_mean),
        # )

        # TODO: implement _upsample_bicubic_aa.out in portable kernel library.
        # here padded shape should be max(h, w) x max(h, w)
        # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable
        # resized = resize(
        #     padded,
        #     size=[
        #         self.image_processor.crop_size["height"],
        #         self.image_processor.crop_size["width"],
        #     ],
        #     interpolation="bicubic",
        # )
        # torch._check(resized.size(1) == self.config.crop_size["height"])
        # torch._check(resized.size(2) == self.config.crop_size["width"])
        # cropped = F.center_crop(img, output_size=[w, w])
        # Rescale pixel values (typically 1/255) then normalize per channel.
        scaled = resized * self.image_processor.rescale_factor
        normed = F.normalize(
            scaled, self.image_processor.image_mean, self.image_processor.image_std
        )
        return normed.unsqueeze(0)

    def step(
        self, token: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Input is one token. Return logits for next token."""
        token_embeds = self.embed_tokens(token).unsqueeze(0)
        return self.text_model.forward(None, input_pos, token_embeds)

    def image_embedding(self, images: torch.Tensor) -> torch.Tensor:
        """Preprocess then encode an image into text-space embeddings."""
        preprocessed_img = self.image_preprocess(images)
        return self.encode_images(preprocessed_img)

    def prefill_embedding(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate [text-before | image | text-after] embeddings along the
        sequence dimension."""
        image_embeds = self.image_embedding(images)
        embeds_before_img = self.embed_tokens(prompt_before_image)
        embeds_after_img = self.embed_tokens(prompt_after_image)
        result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
        return result

    # prefill using the in house text_model of llama transformer
    def prefill(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> Tuple[int, torch.Tensor]:
        """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        # returns the prefilled token length too, because the text model generates one logits in each forward call.
        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

    # reference prefill using the text model in HF
    def prefill_ref(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        # Call the HF decoder directly on embeddings (no KV cache) as a
        # numerical reference for the in-house prefill above.
        return LlamaForCausalLM.forward(
            self.model_.language_model,
            inputs_embeds=embeds,
            return_dict=False,
            use_cache=False,
            output_hidden_states=False,
        )

    def forward(
        self,
        images: torch.Tensor,
    ) -> torch.Tensor:
        """Export entry point: image tensor -> projected image embeddings."""
        return self.image_embedding(images)
258*523fa7a6SAndroid Build Coastguard Worker
259*523fa7a6SAndroid Build Coastguard Worker
class LlavaModel(EagerModelBase):
    """Loads llava-1.5-7b-hf in eager mode and serves example/export inputs.

    The demo image and the split prompt are built lazily on first request
    and cached on the instance.
    """

    def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.max_seq_len = max_seq_len
        ckpt = "llava-hf/llava-1.5-7b-hf"
        self.model_name = "llava-1.5-7b-hf"
        # The HF processor bundles the tokenizer and the CLIP image
        # processor; keep direct handles to both.
        self.processor = AutoProcessor.from_pretrained(ckpt)
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        self.model = LlavaForConditionalGeneration.from_pretrained(
            ckpt,
            device_map="cpu",
        )
        # Demo image, streamed over HTTP and decoded by PIL.
        response = requests.get(
            "https://llava-vl.github.io/static/images/view.jpg", stream=True
        )
        self.image = Image.open(response.raw)
        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
What are the things I should be cautious about when I visit here? ASSISTANT:"""
        # Lazily-built caches, populated by the getters below.
        self.input = None
        self.resized_image = None

    def get_eager_model(self):
        """Construct the eager ``Llava`` module in float32."""
        llava = Llava(
            self.model,
            self.image_processor,
            self.use_sdpa_with_kv_cache_op,
            self.max_seq_len,
        )
        llava.to(dtype=torch.float32)
        return llava

    def get_example_inputs(self):
        """Returns a resized image as input to model.forward()."""
        if self.resized_image is None:
            resized = prepare_image(
                self.image,
                self.image_processor.crop_size["height"],
                self.image_processor.crop_size["width"],
            )
            self.resized_image = (resized,)
        return self.resized_image

    def get_inputs_for_prefill(self):
        """Returns prompts as well as image."""
        if self.input is not None:
            return self.input
        token_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
        self.input_ids = token_ids
        # Split the prompt around the single <image> placeholder token.
        img_pos = torch.where(token_ids == self.model.config.image_token_index)[1]
        self.prompt_before_image = token_ids[:, :img_pos]
        self.prompt_after_image = token_ids[:, img_pos + 1 :]
        (example_image,) = self.get_example_inputs()
        self.input = (
            self.prompt_before_image,
            example_image,
            self.prompt_after_image,
        )
        return self.input

    def get_dynamic_shapes(self):
        """Dynamic export shapes for the image-embedding entry point."""
        return self._get_image_dynamic_shapes()

    def _get_image_dynamic_shapes(self):
        # Export only supports even heights/widths for now, so each
        # dimension is expressed as 2 * Dim.
        half_h = Dim(
            "_height", min=1, max=self.image_processor.crop_size["height"] // 2
        )
        half_w = Dim("_width", min=1, max=self.image_processor.crop_size["width"] // 2)
        return [{1: 2 * half_h, 2: 2 * half_w}]

    def _get_prompt_dynamic_shapes(self):
        # Batch is fixed at 1; token count varies up to max_seq_len.
        token_dim = Dim("token_dim", min=2, max=self.max_seq_len)
        return ({0: 1}, {1: token_dim})
340