1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 // Given inputs, run a text decoder in Llava and return the output. 10 11 #pragma once 12 13 #include <executorch/extension/llm/runner/text_decoder_runner.h> 14 15 namespace example { 16 17 class ET_EXPERIMENTAL LlavaTextDecoderRunner 18 : public executorch::extension::llm::TextDecoderRunner { 19 public: LlavaTextDecoderRunner(executorch::extension::Module * module,int32_t vocab_size,float temperature)20 LlavaTextDecoderRunner( 21 executorch::extension::Module* module, 22 int32_t vocab_size, 23 float temperature) 24 : TextDecoderRunner(module, true, vocab_size, temperature){}; 25 step(executorch::extension::TensorPtr & tokens,executorch::extension::TensorPtr & start_pos)26 inline executorch::runtime::Result<exec_aten::Tensor> step( 27 executorch::extension::TensorPtr& tokens, 28 executorch::extension::TensorPtr& start_pos) override { 29 // run token embedding 30 auto token_embedding_outputs = 31 ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens)); 32 33 // run text model 34 auto outputs_res = ET_UNWRAP(module_->execute( 35 kTextModelMethod, {start_pos, token_embedding_outputs[0]})); 36 37 ET_CHECK_MSG( 38 outputs_res.size() == 1, 39 "More then one output returned from executing LLM."); 40 ET_CHECK_MSG( 41 outputs_res[0].isTensor(), 42 "Non Tensor Output returned from executing LLM"); 43 44 // Return the logits tensor 45 return outputs_res[0].toTensor(); 46 } 47 48 /** 49 * Load the Module for text decode purpose. 50 * @return The error code. 51 */ load()52 inline executorch::runtime::Error load() override { 53 if (is_method_loaded()) { 54 return executorch::runtime::Error::Ok; 55 } 56 ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod)); 57 ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); 58 return executorch::runtime::Error::Ok; 59 } 60 61 /** 62 * Check if the required methods in the Module is loaded. 63 * @return True if the Module is loaded, false otherwise. 64 */ is_method_loaded()65 inline bool is_method_loaded() override { 66 executorch::runtime::Result<std::unordered_set<std::string>> methods_res = 67 module_->method_names(); 68 if (methods_res.error() != executorch::runtime::Error::Ok) { 69 ET_CHECK_MSG(false, "Failed to get method names"); 70 } 71 std::unordered_set<std::string> methods = methods_res.get(); 72 bool methods_exist = methods.find(kTokenEmbeddingMethod) != methods.end() && 73 methods.find(kTextModelMethod) != methods.end(); 74 if (!methods_exist) { 75 for (const auto& method : methods) { 76 ET_LOG(Error, "Method: %s", method.c_str()); 77 } 78 ET_CHECK_MSG( 79 methods_exist, 80 "Missing required methods (%s, %s) in the model", 81 kTokenEmbeddingMethod.c_str(), 82 kTextModelMethod.c_str()); 83 } 84 bool methods_loaded = module_->is_method_loaded(kTokenEmbeddingMethod) && 85 module_->is_method_loaded(kTextModelMethod); 86 return methods_loaded; 87 } 88 89 inline static const std::string kTokenEmbeddingMethod = "token_embedding"; 90 inline static const std::string kTextModelMethod = "text_model"; 91 }; 92 93 } // namespace example 94