# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

import torch
from executorch.examples.models.llava.export_llava import export_all

from executorch.examples.models.llava.model import LlavaModel

# Import order matters. We need to import portable_lib first since it contains
# the static op registry, which is used during the import of the custom ops
# below; otherwise, the registration of the custom ops would be skipped.
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa # usort: skip

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TestLlava(unittest.TestCase):
    def setUp(self):
        self.llava_model = LlavaModel()
        self.llava = self.llava_model.get_eager_model()
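        # Prefill inputs: the tokenized prompt segment before the image, the
        # resized image tensor, and the tokenized prompt segment after it.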
        self.prompt_before_image, self.resized, self.prompt_after_image = (
            self.llava_model.get_inputs_for_prefill()
        )

    def test_prefill_logits(self):
        # For efficiency, the implemented prefill function only outputs the last logits.
        _, prefill_logits = self.llava.prefill(
            self.prompt_before_image, self.resized, self.prompt_after_image
        )
        # The reference implementation in HF generates the full logits. Get the last one.
        prefill_logits_ref = self.llava.prefill_ref(
            self.prompt_before_image, self.resized, self.prompt_after_image
        )[0][:, -1, :]
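        # The two paths can differ slightly in floating-point accumulation,
        # hence the loose absolute tolerance instead of exact equality.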
        self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))

    def test_generated_output(self):
        # source of truth, using HF llava
        preprocessed = self.llava.image_preprocess(self.resized)
        with torch.inference_mode():
            output_ids = self.llava_model.model.generate(
                self.llava_model.input_ids,
                pixel_values=preprocessed,
                do_sample=False,
                num_beams=1,
                max_new_tokens=5,
                use_cache=True,
            )
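        # do_sample=False with num_beams=1 means greedy decoding, so the
        # token-by-token argmax loop below should match it exactly.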
        # The output includes the prompt; keep only the 5 newly generated tokens.
        output_ids = output_ids[:, -5:]
        ref_outputs = self.llava_model.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True
        )[0].strip()

        # being tested, using llama_transformer
        context_len, prefill_logits = self.llava.prefill(
            self.prompt_before_image, self.resized, self.prompt_after_image
        )
        # Always generate one token at a time.
        new_tokens = [torch.argmax(prefill_logits).item()]
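        # step() takes a single token id and its KV-cache position; argmax over
        # the last position's logits greedily picks the next token.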
        for i in range(4):
            logits = self.llava.step(
                torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
            )
            new_tokens.append(torch.argmax(logits[-1, :]).item())

        outputs = self.llava_model.tokenizer.batch_decode(
            torch.tensor([new_tokens]), skip_special_tokens=True
        )[0].strip()
        self.assertEqual(outputs, ref_outputs)

    def test_llava_export(self):
        # Export llava and make sure the end-to-end flow works.
        llava_model = LlavaModel(use_sdpa_with_kv_cache_op=True)

        prompt_before_image, resized, prompt_after_image = (
            llava_model.get_inputs_for_prefill()
        )
        executorch_program = export_all(llava_model)
        llava_module = _load_for_executorch_from_buffer(executorch_program.buffer)

        start_pos = 0
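        # The exported program bundles three methods exercised below:
        # "token_embedding", "image_encoder", and "text_model".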
        # pte prefill prompt before img
        pte_embeds_before_img = llava_module.run_method(
            "token_embedding", (prompt_before_image,)
        )[0]
        llava_module.run_method(
            "text_model",
            (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
        )

        # Advance start_pos, which indexes into the KV cache. The source of
        # truth for the delta is the embedding sequence length, not the logits.
        start_pos += pte_embeds_before_img.shape[1]

        # pte prefill image
        pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
        llava_module.run_method(
            "text_model",
            (
                torch.tensor([start_pos], dtype=torch.int64),
                pte_embeds_img,
            ),
        )

        # Advance start_pos past the prefilled image embeddings.
        start_pos += pte_embeds_img.shape[1]

        # pte prefill prompt after img
        pte_embeds_after_img = llava_module.run_method(
            "token_embedding", (prompt_after_image,)
        )[0]
        pte_prefill_after_img = llava_module.run_method(
            "text_model",
            (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
        )[0]

        # Advance start_pos past the prefilled trailing prompt embeddings.
        start_pos += pte_embeds_after_img.shape[1]

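        # The full prompt (text before image, image, text after image) is now in
        # the KV cache; greedy decoding proceeds one token per step from here.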
        # being tested, using llama_transformer
        new_tokens = [torch.argmax(pte_prefill_after_img).item()]
        # TODO: uncomment this line
        # self.assertEqual(new_tokens[0], 1932)  # When
        for i in range(4):
            logger.info("%d: %s", i, llava_model.tokenizer.decode(new_tokens[i]))
            token_embeds = llava_module.run_method(
                "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
            )[0]
            logits = llava_module.run_method(
                "text_model",
                (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
            )[0]
            new_tokens.append(torch.argmax(logits).item())

        outputs = llava_model.tokenizer.batch_decode(
            torch.tensor([new_tokens]), skip_special_tokens=True
        )[0].strip()
        logger.info(outputs)
        self.assertEqual(len(new_tokens), 5)
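

# Convenience entry point so the tests can also be run directly, e.g.
# `python test_llava.py`; normally a test runner discovers this module.
if __name__ == "__main__":
    unittest.main()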