# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# An ExecuTorch friendly implementation of Llava-1.5.

import re

from typing import Any, Dict, Optional, Tuple

import requests
import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.examples.models.llava.image_util import prepare_image
from executorch.examples.models.model_base import EagerModelBase
from PIL import Image

from torch.export import Dim
from torchvision.transforms.v2 import functional as F

from transformers import (
    AutoProcessor,
    CLIPImageProcessor,
    LlamaForCausalLM,
    LlavaForConditionalGeneration,
)


class Llava(torch.nn.Module):
    """Export-friendly wrapper around HF Llava-1.5.

    Reuses the HF vision tower and multi-modal projector, but swaps the HF
    text decoder for the in-house ExecuTorch Llama ``Transformer`` (with KV
    cache), loading the HF language-model weights into it via a key
    translation map.
    """

    def __init__(
        self,
        llava_model: LlavaForConditionalGeneration,
        image_processor: CLIPImageProcessor,
        use_sdpa_with_kv_cache_op: bool = True,
        max_seq_len: int = 768,
    ):
        super().__init__()
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.model_ = llava_model
        self.image_processor = image_processor
        # Which vision-tower hidden state to use, and whether to drop the CLS
        # token ("default") or keep it ("full") — taken from the HF config.
        self.vision_feature_layer = self.model_.config.vision_feature_layer
        self.vision_feature_select_strategy = (
            self.model_.config.vision_feature_select_strategy
        )
        self.text_model_args = ModelArgs(
            use_kv_cache=True,
            vocab_size=self.model_.config.text_config.vocab_size,
            hidden_dim=self.model_.config.text_config.intermediate_size,
            max_batch_size=1,  # doesn't work with default batch size 32
            ffn_dim_multiplier=1,  # TODO: a hack to make rotary embedding happy
            enable_dynamic_shape=True,  # allow parallel prefill
            use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op,  # use sdpa_with_kv_cache op
            use_hf_rope=True,
            max_seq_len=max_seq_len,
        )
        self.text_model = Transformer(self.text_model_args)
        # use custom op for SDPA.
        if use_sdpa_with_kv_cache_op:
            self.text_model = replace_sdpa_with_custom_op(self.text_model)
        # Load the HF language-model weights into the ExecuTorch Transformer.
        # strict=False because tok_embeddings is deliberately not mapped (the
        # embedding lookup stays on the HF side, see embed_tokens()).
        self.text_model.load_state_dict(
            state_dict=self._translate_state_dict_for_text_model(),
            strict=False,
            assign=True,
        )

    def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
        """Rename HF language-model state-dict keys to Llama Transformer keys."""
        state_dict = self.model_.language_model.state_dict()
        key_map = {
            # fmt: off
            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
            r"model.norm.": r"norm.",
            # r"model.embed_tokens.": r"tok_embeddings.", # load separately
            r"lm_head.": r"output.",
            # fmt: on
        }

        new_state_dict = {}

        def get_new_key(old_key: str) -> str:
            # First pattern that actually changes the key wins; unmapped keys
            # pass through unchanged.
            for old_pattern, replacement in key_map.items():
                if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                    return new_key

            return old_key

        # Convert module keys from hf transformer to Llama transformer.
        for old_key in state_dict.keys():
            new_key = get_new_key(old_key)

            new_state_dict[new_key] = state_dict[old_key]

        return new_state_dict

    def _feature_select(self, image_outputs):
        """Pick the configured hidden state and apply the select strategy."""
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            # Drop the leading CLS token.
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )
        return selected_image_feature

    def get_model(self):
        """Delegate to the wrapped HF model."""
        return self.model_.get_model()

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the HF language model's embedding table."""
        return self.model_.language_model.model.embed_tokens(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Run images through the vision tower and multi-modal projector."""
        images = images.to(dtype=self.model_.dtype)
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.model_.vision_tower(
                    image.to(
                        device=self.model_.device, dtype=self.model_.dtype
                    ).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self._feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.model_.vision_tower(
                images.to(device=self.model_.device, dtype=self.model_.dtype),
                output_hidden_states=True,
            )
            image_features = self._feature_select(image_forward_outs).to(images.dtype)
        image_features = self.model_.multi_modal_projector(image_features)
        return image_features

    def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Pad, rescale and normalize a CHW image; returns a 1xCxHxW tensor.

        NOTE(review): assumes img is (C, H, W) and no larger than the
        processor crop size — the torch._check guards below fail otherwise.
        """
        target_h = self.image_processor.crop_size["height"]
        target_w = self.image_processor.crop_size["width"]
        # pad the image with median rgb value, to make a square
        l_pad = (target_w - img.shape[2]) // 2
        t_pad = (target_h - img.shape[1]) // 2
        # ceil division
        r_pad = -((target_w - img.shape[2]) // -2)
        b_pad = -((target_h - img.shape[1]) // -2)

        torch._check(l_pad >= 0)
        torch._check(t_pad >= 0)
        torch._check(r_pad >= 0)
        torch._check(b_pad >= 0)

        # This is different from the original implementation, due to export limitations.
        resized = torch.nn.functional.pad(
            img,
            (l_pad, r_pad, t_pad, b_pad),
        )
        # originally:
        # resized = F.pad(
        #     img,
        #     padding=(l_pad, t_pad, r_pad, b_pad),
        #     fill=tuple(int(x * 255) for x in self.image_mean),
        # )

        # TODO: implement _upsample_bicubic_aa.out in portable kernel library.
        # here padded shape should be max(h, w) x max(h, w)
        # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable
        # resized = resize(
        #     padded,
        #     size=[
        #         self.image_processor.crop_size["height"],
        #         self.image_processor.crop_size["width"],
        #     ],
        #     interpolation="bicubic",
        # )
        # torch._check(resized.size(1) == self.config.crop_size["height"])
        # torch._check(resized.size(2) == self.config.crop_size["width"])
        # print(resized.shape)
        # cropped = F.center_crop(img, output_size=[w, w])
        # print(cropped.shape)
        scaled = resized * self.image_processor.rescale_factor
        # print(scaled)
        normed = F.normalize(
            scaled, self.image_processor.image_mean, self.image_processor.image_std
        )
        # print(normed)
        return normed.unsqueeze(0)

    def step(
        self, token: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Input is one token. Return logits for next token."""
        token_embeds = self.embed_tokens(token).unsqueeze(0)
        return self.text_model.forward(None, input_pos, token_embeds)

    def image_embedding(self, images: torch.Tensor) -> torch.Tensor:
        """Preprocess then encode an image into projected embeddings."""
        preprocessed_img = self.image_preprocess(images)
        return self.encode_images(preprocessed_img)

    def prefill_embedding(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate text embeddings around the image embeddings (dim=1)."""
        image_embeds = self.image_embedding(images)
        embeds_before_img = self.embed_tokens(prompt_before_image)
        embeds_after_img = self.embed_tokens(prompt_after_image)
        result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
        return result

    # prefill using the in house text_model of llama transformer
    def prefill(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> Tuple[int, torch.Tensor]:
        """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        # returns the prefilled token length too, because the text model generates one logits in each forward call.
        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

    # reference prefill using the text model in HF
    def prefill_ref(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        return LlamaForCausalLM.forward(
            self.model_.language_model,
            inputs_embeds=embeds,
            return_dict=False,
            use_cache=False,
            output_hidden_states=False,
        )

    def forward(
        self,
        images: torch.Tensor,
    ) -> torch.Tensor:
        """Export entry point: image -> projected image embeddings."""
        return self.image_embedding(images)


class LlavaModel(EagerModelBase):
    """Eager-model factory for llava-1.5-7b-hf with canned example inputs."""

    def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.max_seq_len = max_seq_len
        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        self.model = LlavaForConditionalGeneration.from_pretrained(
            "llava-hf/llava-1.5-7b-hf",
            device_map="cpu",
        )
        # Example image fetched over the network at construction time.
        self.image = Image.open(
            requests.get(
                "https://llava-vl.github.io/static/images/view.jpg", stream=True
            ).raw
        )
        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
What are the things I should be cautious about when I visit here? ASSISTANT:"""
        self.model_name = "llava-1.5-7b-hf"
        # set input to None and initialize them lazily
        self.input = None
        self.resized_image = None

    def get_eager_model(self):
        """Build the Llava wrapper module in float32."""
        model = Llava(
            self.model,
            self.image_processor,
            self.use_sdpa_with_kv_cache_op,
            self.max_seq_len,
        )
        model.to(dtype=torch.float32)
        return model

    def get_example_inputs(self):
        """Returns a resized image as input to model.forward()."""
        if self.resized_image:
            return self.resized_image
        resized = prepare_image(
            self.image,
            self.image_processor.crop_size["height"],
            self.image_processor.crop_size["width"],
        )
        # Cached as a 1-tuple to match the forward() signature.
        self.resized_image = (resized,)
        return self.resized_image

    def get_inputs_for_prefill(self):
        """Returns prompts as well as image."""
        if self.input:
            return self.input
        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
        # Split the token ids around the single <image> placeholder token.
        index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
        self.prompt_before_image = self.input_ids[:, :index]
        # print(prompt_before_image.shape)
        self.prompt_after_image = self.input_ids[:, index + 1 :]
        # print(prompt_after_image.shape)
        self.input = (
            self.prompt_before_image,
            *self.get_example_inputs(),
            self.prompt_after_image,
        )
        return self.input

    def get_dynamic_shapes(self):
        """Dynamic-shape spec for the image input of forward()."""
        return self._get_image_dynamic_shapes()

    def _get_image_dynamic_shapes(self):
        # only support even number of height and width for now
        _height = Dim(
            "_height", min=1, max=self.image_processor.crop_size["height"] // 2
        )
        _width = Dim("_width", min=1, max=self.image_processor.crop_size["width"] // 2)
        height = 2 * _height
        width = 2 * _width
        dynamic_shapes = [{1: height, 2: width}]
        return dynamic_shapes

    def _get_prompt_dynamic_shapes(self):
        # Token dimension ranges from 2 up to the configured max sequence length.
        dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len)
        text_model_dynamic_shapes = ({0: 1}, {1: dim})
        return text_model_dynamic_shapes