xref: /aosp_15_r20/external/executorch/examples/models/llava/runner/llava_text_decoder_runner.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Given inputs, run a text decoder in Llava and return the output.
10 
11 #pragma once
12 
13 #include <executorch/extension/llm/runner/text_decoder_runner.h>
14 
15 namespace example {
16 
17 class ET_EXPERIMENTAL LlavaTextDecoderRunner
18     : public executorch::extension::llm::TextDecoderRunner {
19  public:
LlavaTextDecoderRunner(executorch::extension::Module * module,int32_t vocab_size,float temperature)20   LlavaTextDecoderRunner(
21       executorch::extension::Module* module,
22       int32_t vocab_size,
23       float temperature)
24       : TextDecoderRunner(module, true, vocab_size, temperature){};
25 
step(executorch::extension::TensorPtr & tokens,executorch::extension::TensorPtr & start_pos)26   inline executorch::runtime::Result<exec_aten::Tensor> step(
27       executorch::extension::TensorPtr& tokens,
28       executorch::extension::TensorPtr& start_pos) override {
29     // run token embedding
30     auto token_embedding_outputs =
31         ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));
32 
33     // run text model
34     auto outputs_res = ET_UNWRAP(module_->execute(
35         kTextModelMethod, {start_pos, token_embedding_outputs[0]}));
36 
37     ET_CHECK_MSG(
38         outputs_res.size() == 1,
39         "More then one output returned from executing LLM.");
40     ET_CHECK_MSG(
41         outputs_res[0].isTensor(),
42         "Non Tensor Output returned from executing LLM");
43 
44     // Return the logits tensor
45     return outputs_res[0].toTensor();
46   }
47 
48   /**
49    * Load the Module for text decode purpose.
50    * @return The error code.
51    */
load()52   inline executorch::runtime::Error load() override {
53     if (is_method_loaded()) {
54       return executorch::runtime::Error::Ok;
55     }
56     ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod));
57     ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod));
58     return executorch::runtime::Error::Ok;
59   }
60 
61   /**
62    * Check if the required methods in the Module is loaded.
63    * @return True if the Module is loaded, false otherwise.
64    */
is_method_loaded()65   inline bool is_method_loaded() override {
66     executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
67         module_->method_names();
68     if (methods_res.error() != executorch::runtime::Error::Ok) {
69       ET_CHECK_MSG(false, "Failed to get method names");
70     }
71     std::unordered_set<std::string> methods = methods_res.get();
72     bool methods_exist = methods.find(kTokenEmbeddingMethod) != methods.end() &&
73         methods.find(kTextModelMethod) != methods.end();
74     if (!methods_exist) {
75       for (const auto& method : methods) {
76         ET_LOG(Error, "Method: %s", method.c_str());
77       }
78       ET_CHECK_MSG(
79           methods_exist,
80           "Missing required methods (%s, %s) in the model",
81           kTokenEmbeddingMethod.c_str(),
82           kTextModelMethod.c_str());
83     }
84     bool methods_loaded = module_->is_method_loaded(kTokenEmbeddingMethod) &&
85         module_->is_method_loaded(kTextModelMethod);
86     return methods_loaded;
87   }
88 
89   inline static const std::string kTokenEmbeddingMethod = "token_embedding";
90   inline static const std::string kTextModelMethod = "text_model";
91 };
92 
93 } // namespace example
94