1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 // Given a image tensor, prefill the KV cache of LLaVA. 10 11 #pragma once 12 13 #include <executorch/extension/llm/runner/image_prefiller.h> 14 #include <executorch/extension/tensor/tensor.h> 15 16 namespace example { 17 18 class ET_EXPERIMENTAL LlavaImagePrefiller 19 : public ::executorch::extension::llm::ImagePrefiller { 20 public: LlavaImagePrefiller(::executorch::extension::Module * module)21 LlavaImagePrefiller(::executorch::extension::Module* module) 22 : ImagePrefiller(module){}; 23 /** 24 * Prefill an LLM Module with the given image input. 25 * @param image The image input to LLaVa. 26 * @param start_pos The starting position in KV cache of the input in the LLM 27 * @return logits of the image prefill. 28 */ prefill(::executorch::extension::llm::Image & image,int64_t & start_pos)29 inline ::executorch::runtime::Result<exec_aten::Tensor> prefill( 30 ::executorch::extension::llm::Image& image, 31 int64_t& start_pos) override { 32 auto image_tensor = executorch::extension::from_blob( 33 image.data.data(), 34 {3, image.height, image.width}, 35 ::executorch::aten::ScalarType::Byte); 36 // Run image encoder 37 auto image_encoder_outputs = 38 ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor)); 39 40 // inputs:[start_pos, embeds] 41 auto start_pos_tensor = executorch::extension::from_blob( 42 &start_pos, {1}, ::executorch::aten::ScalarType::Long); 43 44 // Run text model 45 auto outputs_res = ET_UNWRAP(module_->execute( 46 kTextModelMethod, {start_pos_tensor, image_encoder_outputs[0]})); 47 ET_CHECK_MSG( 48 outputs_res[0].isTensor(), 49 "Non Tensor Output returned from executing image prefill"); 50 51 // Update the start_pos, which is only available inside this function. 52 // outputs_res can have only one logits. 53 start_pos += image_encoder_outputs[0].toTensor().size(1); 54 55 return outputs_res[0].toTensor(); 56 } 57 58 /** 59 * Load the Module for image prefill purpose. 60 * @return The error code. 61 */ load()62 inline ::executorch::runtime::Error load() override { 63 if (is_method_loaded()) { 64 return ::executorch::runtime::Error::Ok; 65 } 66 ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod)); 67 ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); 68 return ::executorch::runtime::Error::Ok; 69 } 70 71 /** 72 * Check if the required methods in the Module is loaded. 73 * @return True if the Module is loaded, false otherwise. 74 */ is_method_loaded()75 inline bool is_method_loaded() override { 76 ::executorch::runtime::Result<std::unordered_set<std::string>> methods_res = 77 module_->method_names(); 78 if (methods_res.error() != ::executorch::runtime::Error::Ok) { 79 ET_CHECK_MSG(false, "Failed to get method names"); 80 } 81 std::unordered_set<std::string> methods = methods_res.get(); 82 bool methods_exist = methods.find(kImageEncoderMethod) != methods.end() && 83 methods.find(kTextModelMethod) != methods.end(); 84 if (!methods_exist) { 85 for (const auto& method : methods) { 86 ET_LOG(Error, "Method: %s", method.c_str()); 87 } 88 ET_CHECK_MSG( 89 methods_exist, 90 "Missing required methods (%s, %s) in the model", 91 kImageEncoderMethod.c_str(), 92 kTextModelMethod.c_str()); 93 } 94 bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) && 95 module_->is_method_loaded(kTextModelMethod); 96 return methods_loaded; 97 } 98 99 inline static const std::string kImageEncoderMethod = "image_encoder"; 100 inline static const std::string kTextModelMethod = "text_model"; 101 }; 102 103 } // namespace example 104