xref: /aosp_15_r20/external/executorch/examples/models/llava/runner/llava_image_prefiller.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Given an image tensor, prefill the KV cache of LLaVA.
10 
11 #pragma once
12 
13 #include <executorch/extension/llm/runner/image_prefiller.h>
14 #include <executorch/extension/tensor/tensor.h>
15 
16 namespace example {
17 
18 class ET_EXPERIMENTAL LlavaImagePrefiller
19     : public ::executorch::extension::llm::ImagePrefiller {
20  public:
LlavaImagePrefiller(::executorch::extension::Module * module)21   LlavaImagePrefiller(::executorch::extension::Module* module)
22       : ImagePrefiller(module){};
23   /**
24    * Prefill an LLM Module with the given image input.
25    * @param image The image input to LLaVa.
26    * @param start_pos The starting position in KV cache of the input in the LLM
27    * @return logits of the image prefill.
28    */
prefill(::executorch::extension::llm::Image & image,int64_t & start_pos)29   inline ::executorch::runtime::Result<exec_aten::Tensor> prefill(
30       ::executorch::extension::llm::Image& image,
31       int64_t& start_pos) override {
32     auto image_tensor = executorch::extension::from_blob(
33         image.data.data(),
34         {3, image.height, image.width},
35         ::executorch::aten::ScalarType::Byte);
36     // Run image encoder
37     auto image_encoder_outputs =
38         ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor));
39 
40     // inputs:[start_pos, embeds]
41     auto start_pos_tensor = executorch::extension::from_blob(
42         &start_pos, {1}, ::executorch::aten::ScalarType::Long);
43 
44     // Run text model
45     auto outputs_res = ET_UNWRAP(module_->execute(
46         kTextModelMethod, {start_pos_tensor, image_encoder_outputs[0]}));
47     ET_CHECK_MSG(
48         outputs_res[0].isTensor(),
49         "Non Tensor Output returned from executing image prefill");
50 
51     // Update the start_pos, which is only available inside this function.
52     // outputs_res can have only one logits.
53     start_pos += image_encoder_outputs[0].toTensor().size(1);
54 
55     return outputs_res[0].toTensor();
56   }
57 
58   /**
59    * Load the Module for image prefill purpose.
60    * @return The error code.
61    */
load()62   inline ::executorch::runtime::Error load() override {
63     if (is_method_loaded()) {
64       return ::executorch::runtime::Error::Ok;
65     }
66     ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod));
67     ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod));
68     return ::executorch::runtime::Error::Ok;
69   }
70 
71   /**
72    * Check if the required methods in the Module is loaded.
73    * @return True if the Module is loaded, false otherwise.
74    */
is_method_loaded()75   inline bool is_method_loaded() override {
76     ::executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
77         module_->method_names();
78     if (methods_res.error() != ::executorch::runtime::Error::Ok) {
79       ET_CHECK_MSG(false, "Failed to get method names");
80     }
81     std::unordered_set<std::string> methods = methods_res.get();
82     bool methods_exist = methods.find(kImageEncoderMethod) != methods.end() &&
83         methods.find(kTextModelMethod) != methods.end();
84     if (!methods_exist) {
85       for (const auto& method : methods) {
86         ET_LOG(Error, "Method: %s", method.c_str());
87       }
88       ET_CHECK_MSG(
89           methods_exist,
90           "Missing required methods (%s, %s) in the model",
91           kImageEncoderMethod.c_str(),
92           kTextModelMethod.c_str());
93     }
94     bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) &&
95         module_->is_method_loaded(kTextModelMethod);
96     return methods_loaded;
97   }
98 
99   inline static const std::string kImageEncoderMethod = "image_encoder";
100   inline static const std::string kTextModelMethod = "text_model";
101 };
102 
103 } // namespace example
104