# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

import torch
from executorch.examples.models.llava.export_llava import export_all

from executorch.examples.models.llava.model import LlavaModel

# Import order matters: portable_lib must be imported first because it contains the
# static op registry that the custom-op imports rely on. Otherwise the custom op
# registration is silently skipped. The `usort: skip` markers keep the import sorter
# from reordering these lines.
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa # usort: skip

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TestLlava(unittest.TestCase):
    def setUp(self):
        self.llava_model = LlavaModel()
        self.llava = self.llava_model.get_eager_model()
        self.prompt_before_image, self.resized, self.prompt_after_image = (
            self.llava_model.get_inputs_for_prefill()
        )

    def test_prefill_logits(self):
        # For efficiency, the implemented prefill function only returns the logits of
        # the last token.
        _, prefill_logits = self.llava.prefill(
            self.prompt_before_image, self.resized, self.prompt_after_image
        )
        # The HF reference implementation generates logits for the full sequence; take
        # the last one.
        prefill_logits_ref = self.llava.prefill_ref(
            self.prompt_before_image, self.resized, self.prompt_after_image
        )[0][:, -1, :]
        self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))

    def test_generated_output(self):
        # Source of truth: generate with the HF LLaVA model.
        preprocessed = self.llava.image_preprocess(self.resized)
        with torch.inference_mode():
            output_ids = self.llava_model.model.generate(
                self.llava_model.input_ids,
                pixel_values=preprocessed,
                do_sample=False,
                num_beams=1,
                max_new_tokens=5,
                use_cache=True,
            )
        # The output includes the prompt; keep only the newly generated tokens.
        output_ids = output_ids[:, -5:]
        ref_outputs = self.llava_model.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True
        )[0].strip()

        # Implementation under test, using llama_transformer.
        context_len, prefill_logits = self.llava.prefill(
            self.prompt_before_image, self.resized, self.prompt_after_image
        )
        # Always generate one token at a time.
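        # Greedy decode: seed with the argmax of the prefill logits, then feed each new
        # token back through step() at the next kv-cache position.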
        new_tokens = [torch.argmax(prefill_logits).item()]
        for i in range(4):
            logits = self.llava.step(
                torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
            )
            new_tokens.append(torch.argmax(logits[-1, :]).item())

        outputs = self.llava_model.tokenizer.batch_decode(
            torch.tensor([new_tokens]), skip_special_tokens=True
        )[0].strip()
        self.assertEqual(outputs, ref_outputs)

    def test_llava_export(self):
        # Export LLaVA and make sure it runs end to end.
        llava_model = LlavaModel(use_sdpa_with_kv_cache_op=True)

        prompt_before_image, resized, prompt_after_image = (
            llava_model.get_inputs_for_prefill()
        )
        executorch_program = export_all(llava_model)
        llava_module = _load_for_executorch_from_buffer(executorch_program.buffer)

        start_pos = 0
        # PTE prefill: prompt before the image.
        pte_embeds_before_img = llava_module.run_method(
            "token_embedding", (prompt_before_image,)
        )[0]
        llava_module.run_method(
            "text_model",
            (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
        )

        # Advance start_pos, which indexes into the kv cache. The source of truth for
        # the delta is the embedding length, not the logits.
        start_pos += pte_embeds_before_img.shape[1]

        # PTE prefill: image.
        pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
        llava_module.run_method(
            "text_model",
            (
                torch.tensor([start_pos], dtype=torch.int64),
                pte_embeds_img,
            ),
        )

        # Advance start_pos past the image embeddings.
        start_pos += pte_embeds_img.shape[1]

        # PTE prefill: prompt after the image.
        pte_embeds_after_img = llava_module.run_method(
            "token_embedding", (prompt_after_image,)
        )[0]
        pte_prefill_after_img = llava_module.run_method(
            "text_model",
            (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
        )[0]

        # Advance start_pos past the post-image prompt.
        start_pos += pte_embeds_after_img.shape[1]

        # Implementation under test: greedy decode through the exported text_model.
        new_tokens = [torch.argmax(pte_prefill_after_img).item()]
        # TODO: uncomment this line
        # self.assertEqual(new_tokens[0], 1932)  # When
        for i in range(4):
            print(i, llava_model.tokenizer.decode(new_tokens[i]))
            token_embeds = llava_module.run_method(
                "token_embedding",
                (torch.tensor([[new_tokens[i]]], dtype=torch.int64),),
            )[0]
            logits = llava_module.run_method(
                "text_model",
                (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
            )[0]
            new_tokens.append(torch.argmax(logits).item())

        outputs = llava_model.tokenizer.batch_decode(
            torch.tensor([new_tokens]), skip_special_tokens=True
        )[0].strip()
        print(outputs)
        self.assertEqual(len(new_tokens), 5)
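

if __name__ == "__main__":
    # Convenience entry point so the file can also be executed directly; the suite is
    # normally picked up by the project's test runner, so this guard is optional.
    unittest.main()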