xref: /aosp_15_r20/external/executorch/extension/module/module.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/module/module.h>
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/file_data_loader.h>
12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/mmap_data_loader.h>
13*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/runtime.h>
15*523fa7a6SAndroid Build Coastguard Worker 
/**
 * Unwraps a Result-like expression into an owning std::unique_ptr.
 *
 * `result__` is evaluated exactly once into a local. If it reports an error
 * (`!ok()`), the *enclosing function* returns that error immediately.
 * Otherwise the contained value is moved into a newly allocated object, and
 * the whole expression evaluates to the std::unique_ptr that owns it.
 *
 * The body is a GNU statement expression (`({ ... })`), so it requires a
 * compiler that supports that extension. The enclosing function must return
 * a type constructible from the error, e.g. a Result or an Error.
 *
 * @param[in] result__ Expression yielding the result to unwrap.
 */
#define ET_UNWRAP_UNIQUE(result__)                                       \
  ({                                                                     \
    auto et_unwrap_result__ = (result__);                                \
    if (!et_unwrap_result__.ok()) {                                      \
      return et_unwrap_result__.error();                                 \
    }                                                                    \
    using et_unwrap_value__ =                                            \
        std::remove_reference_t<decltype(*et_unwrap_result__)>;          \
    std::make_unique<et_unwrap_value__>(std::move(*et_unwrap_result__)); \
  })
35*523fa7a6SAndroid Build Coastguard Worker 
36*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
37*523fa7a6SAndroid Build Coastguard Worker namespace extension {
38*523fa7a6SAndroid Build Coastguard Worker 
Module(const std::string & file_path,const LoadMode load_mode,std::unique_ptr<runtime::EventTracer> event_tracer)39*523fa7a6SAndroid Build Coastguard Worker Module::Module(
40*523fa7a6SAndroid Build Coastguard Worker     const std::string& file_path,
41*523fa7a6SAndroid Build Coastguard Worker     const LoadMode load_mode,
42*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::EventTracer> event_tracer)
43*523fa7a6SAndroid Build Coastguard Worker     : file_path_(file_path),
44*523fa7a6SAndroid Build Coastguard Worker       load_mode_(load_mode),
45*523fa7a6SAndroid Build Coastguard Worker       memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
46*523fa7a6SAndroid Build Coastguard Worker       temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47*523fa7a6SAndroid Build Coastguard Worker       event_tracer_(std::move(event_tracer)) {
48*523fa7a6SAndroid Build Coastguard Worker   runtime::runtime_init();
49*523fa7a6SAndroid Build Coastguard Worker }
50*523fa7a6SAndroid Build Coastguard Worker 
Module(std::unique_ptr<runtime::DataLoader> data_loader,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)51*523fa7a6SAndroid Build Coastguard Worker Module::Module(
52*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::DataLoader> data_loader,
53*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
54*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::EventTracer> event_tracer)
56*523fa7a6SAndroid Build Coastguard Worker     : data_loader_(std::move(data_loader)),
57*523fa7a6SAndroid Build Coastguard Worker       memory_allocator_(
58*523fa7a6SAndroid Build Coastguard Worker           memory_allocator ? std::move(memory_allocator)
59*523fa7a6SAndroid Build Coastguard Worker                            : std::make_unique<MallocMemoryAllocator>()),
60*523fa7a6SAndroid Build Coastguard Worker       temp_allocator_(
61*523fa7a6SAndroid Build Coastguard Worker           temp_allocator ? std::move(temp_allocator)
62*523fa7a6SAndroid Build Coastguard Worker                          : std::make_unique<MallocMemoryAllocator>()),
63*523fa7a6SAndroid Build Coastguard Worker       event_tracer_(std::move(event_tracer)) {
64*523fa7a6SAndroid Build Coastguard Worker   runtime::runtime_init();
65*523fa7a6SAndroid Build Coastguard Worker }
66*523fa7a6SAndroid Build Coastguard Worker 
Module(std::shared_ptr<runtime::Program> program,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)67*523fa7a6SAndroid Build Coastguard Worker Module::Module(
68*523fa7a6SAndroid Build Coastguard Worker     std::shared_ptr<runtime::Program> program,
69*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<runtime::EventTracer> event_tracer)
72*523fa7a6SAndroid Build Coastguard Worker     : program_(std::move(program)),
73*523fa7a6SAndroid Build Coastguard Worker       memory_allocator_(
74*523fa7a6SAndroid Build Coastguard Worker           memory_allocator ? std::move(memory_allocator)
75*523fa7a6SAndroid Build Coastguard Worker                            : std::make_unique<MallocMemoryAllocator>()),
76*523fa7a6SAndroid Build Coastguard Worker       temp_allocator_(
77*523fa7a6SAndroid Build Coastguard Worker           temp_allocator ? std::move(temp_allocator)
78*523fa7a6SAndroid Build Coastguard Worker                          : std::make_unique<MallocMemoryAllocator>()),
79*523fa7a6SAndroid Build Coastguard Worker       event_tracer_(std::move(event_tracer)) {
80*523fa7a6SAndroid Build Coastguard Worker   runtime::runtime_init();
81*523fa7a6SAndroid Build Coastguard Worker }
82*523fa7a6SAndroid Build Coastguard Worker 
load(const runtime::Program::Verification verification)83*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::load(const runtime::Program::Verification verification) {
84*523fa7a6SAndroid Build Coastguard Worker   if (!is_loaded()) {
85*523fa7a6SAndroid Build Coastguard Worker     if (!data_loader_) {
86*523fa7a6SAndroid Build Coastguard Worker       switch (load_mode_) {
87*523fa7a6SAndroid Build Coastguard Worker         case LoadMode::File:
88*523fa7a6SAndroid Build Coastguard Worker           data_loader_ =
89*523fa7a6SAndroid Build Coastguard Worker               ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path_.c_str()));
90*523fa7a6SAndroid Build Coastguard Worker           break;
91*523fa7a6SAndroid Build Coastguard Worker         case LoadMode::Mmap:
92*523fa7a6SAndroid Build Coastguard Worker           data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
93*523fa7a6SAndroid Build Coastguard Worker               file_path_.c_str(), MmapDataLoader::MlockConfig::NoMlock));
94*523fa7a6SAndroid Build Coastguard Worker           break;
95*523fa7a6SAndroid Build Coastguard Worker         case LoadMode::MmapUseMlock:
96*523fa7a6SAndroid Build Coastguard Worker           data_loader_ =
97*523fa7a6SAndroid Build Coastguard Worker               ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path_.c_str()));
98*523fa7a6SAndroid Build Coastguard Worker           break;
99*523fa7a6SAndroid Build Coastguard Worker         case LoadMode::MmapUseMlockIgnoreErrors:
100*523fa7a6SAndroid Build Coastguard Worker           data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
101*523fa7a6SAndroid Build Coastguard Worker               file_path_.c_str(),
102*523fa7a6SAndroid Build Coastguard Worker               MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103*523fa7a6SAndroid Build Coastguard Worker           break;
104*523fa7a6SAndroid Build Coastguard Worker       }
105*523fa7a6SAndroid Build Coastguard Worker     };
106*523fa7a6SAndroid Build Coastguard Worker     auto program = ET_UNWRAP_UNIQUE(
107*523fa7a6SAndroid Build Coastguard Worker         runtime::Program::load(data_loader_.get(), verification));
108*523fa7a6SAndroid Build Coastguard Worker     program_ = std::shared_ptr<runtime::Program>(
109*523fa7a6SAndroid Build Coastguard Worker         program.release(), [](runtime::Program* pointer) { delete pointer; });
110*523fa7a6SAndroid Build Coastguard Worker   }
111*523fa7a6SAndroid Build Coastguard Worker   return runtime::Error::Ok;
112*523fa7a6SAndroid Build Coastguard Worker }
113*523fa7a6SAndroid Build Coastguard Worker 
method_names()114*523fa7a6SAndroid Build Coastguard Worker runtime::Result<std::unordered_set<std::string>> Module::method_names() {
115*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(load());
116*523fa7a6SAndroid Build Coastguard Worker   const auto method_count = program_->num_methods();
117*523fa7a6SAndroid Build Coastguard Worker   std::unordered_set<std::string> result;
118*523fa7a6SAndroid Build Coastguard Worker   result.reserve(method_count);
119*523fa7a6SAndroid Build Coastguard Worker 
120*523fa7a6SAndroid Build Coastguard Worker   for (auto index = 0; index < method_count; ++index) {
121*523fa7a6SAndroid Build Coastguard Worker     result.emplace(program_->get_method_name(index).get());
122*523fa7a6SAndroid Build Coastguard Worker   }
123*523fa7a6SAndroid Build Coastguard Worker   return result;
124*523fa7a6SAndroid Build Coastguard Worker }
125*523fa7a6SAndroid Build Coastguard Worker 
load_method(const std::string & method_name,torch::executor::EventTracer * event_tracer)126*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::load_method(
127*523fa7a6SAndroid Build Coastguard Worker     const std::string& method_name,
128*523fa7a6SAndroid Build Coastguard Worker     torch::executor::EventTracer* event_tracer) {
129*523fa7a6SAndroid Build Coastguard Worker   if (!is_method_loaded(method_name)) {
130*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_OK_OR_RETURN_ERROR(load());
131*523fa7a6SAndroid Build Coastguard Worker 
132*523fa7a6SAndroid Build Coastguard Worker     MethodHolder method_holder;
133*523fa7a6SAndroid Build Coastguard Worker     const auto method_metadata =
134*523fa7a6SAndroid Build Coastguard Worker         ET_UNWRAP(program_->method_meta(method_name.c_str()));
135*523fa7a6SAndroid Build Coastguard Worker     const auto planned_buffersCount =
136*523fa7a6SAndroid Build Coastguard Worker         method_metadata.num_memory_planned_buffers();
137*523fa7a6SAndroid Build Coastguard Worker     method_holder.planned_buffers.reserve(planned_buffersCount);
138*523fa7a6SAndroid Build Coastguard Worker     method_holder.planned_spans.reserve(planned_buffersCount);
139*523fa7a6SAndroid Build Coastguard Worker 
140*523fa7a6SAndroid Build Coastguard Worker     for (auto index = 0; index < planned_buffersCount; ++index) {
141*523fa7a6SAndroid Build Coastguard Worker       const auto buffer_size =
142*523fa7a6SAndroid Build Coastguard Worker           method_metadata.memory_planned_buffer_size(index).get();
143*523fa7a6SAndroid Build Coastguard Worker       method_holder.planned_buffers.emplace_back(buffer_size);
144*523fa7a6SAndroid Build Coastguard Worker       method_holder.planned_spans.emplace_back(
145*523fa7a6SAndroid Build Coastguard Worker           method_holder.planned_buffers.back().data(), buffer_size);
146*523fa7a6SAndroid Build Coastguard Worker     }
147*523fa7a6SAndroid Build Coastguard Worker     method_holder.planned_memory =
148*523fa7a6SAndroid Build Coastguard Worker         std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
149*523fa7a6SAndroid Build Coastguard Worker             method_holder.planned_spans.data(),
150*523fa7a6SAndroid Build Coastguard Worker             method_holder.planned_spans.size()));
151*523fa7a6SAndroid Build Coastguard Worker     method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
152*523fa7a6SAndroid Build Coastguard Worker         memory_allocator_.get(),
153*523fa7a6SAndroid Build Coastguard Worker         method_holder.planned_memory.get(),
154*523fa7a6SAndroid Build Coastguard Worker         temp_allocator_.get());
155*523fa7a6SAndroid Build Coastguard Worker     method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156*523fa7a6SAndroid Build Coastguard Worker         method_name.c_str(),
157*523fa7a6SAndroid Build Coastguard Worker         method_holder.memory_manager.get(),
158*523fa7a6SAndroid Build Coastguard Worker         event_tracer ? event_tracer : this->event_tracer()));
159*523fa7a6SAndroid Build Coastguard Worker     method_holder.inputs.resize(method_holder.method->inputs_size());
160*523fa7a6SAndroid Build Coastguard Worker     methods_.emplace(method_name, std::move(method_holder));
161*523fa7a6SAndroid Build Coastguard Worker   }
162*523fa7a6SAndroid Build Coastguard Worker   return runtime::Error::Ok;
163*523fa7a6SAndroid Build Coastguard Worker }
164*523fa7a6SAndroid Build Coastguard Worker 
method_meta(const std::string & method_name)165*523fa7a6SAndroid Build Coastguard Worker runtime::Result<runtime::MethodMeta> Module::method_meta(
166*523fa7a6SAndroid Build Coastguard Worker     const std::string& method_name) {
167*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
168*523fa7a6SAndroid Build Coastguard Worker   return methods_.at(method_name).method->method_meta();
169*523fa7a6SAndroid Build Coastguard Worker }
170*523fa7a6SAndroid Build Coastguard Worker 
execute(const std::string & method_name,const std::vector<runtime::EValue> & input_values)171*523fa7a6SAndroid Build Coastguard Worker runtime::Result<std::vector<runtime::EValue>> Module::execute(
172*523fa7a6SAndroid Build Coastguard Worker     const std::string& method_name,
173*523fa7a6SAndroid Build Coastguard Worker     const std::vector<runtime::EValue>& input_values) {
174*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
175*523fa7a6SAndroid Build Coastguard Worker   auto& method = methods_.at(method_name).method;
176*523fa7a6SAndroid Build Coastguard Worker   auto& inputs = methods_.at(method_name).inputs;
177*523fa7a6SAndroid Build Coastguard Worker 
178*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < input_values.size(); ++i) {
179*523fa7a6SAndroid Build Coastguard Worker     if (!input_values[i].isNone()) {
180*523fa7a6SAndroid Build Coastguard Worker       inputs[i] = input_values[i];
181*523fa7a6SAndroid Build Coastguard Worker     }
182*523fa7a6SAndroid Build Coastguard Worker   }
183*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < inputs.size(); ++i) {
184*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_OR_RETURN_ERROR(
185*523fa7a6SAndroid Build Coastguard Worker         !inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
186*523fa7a6SAndroid Build Coastguard Worker   }
187*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
188*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ArrayRef<runtime::EValue>(inputs.data(), inputs.size())));
189*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
190*523fa7a6SAndroid Build Coastguard Worker 
191*523fa7a6SAndroid Build Coastguard Worker   const auto outputs_size = method->outputs_size();
192*523fa7a6SAndroid Build Coastguard Worker   std::vector<runtime::EValue> outputs(outputs_size);
193*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(
194*523fa7a6SAndroid Build Coastguard Worker       method->get_outputs(outputs.data(), outputs_size));
195*523fa7a6SAndroid Build Coastguard Worker 
196*523fa7a6SAndroid Build Coastguard Worker   return outputs;
197*523fa7a6SAndroid Build Coastguard Worker }
198*523fa7a6SAndroid Build Coastguard Worker 
set_input(const std::string & method_name,const runtime::EValue & input_value,size_t input_index)199*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::set_input(
200*523fa7a6SAndroid Build Coastguard Worker     const std::string& method_name,
201*523fa7a6SAndroid Build Coastguard Worker     const runtime::EValue& input_value,
202*523fa7a6SAndroid Build Coastguard Worker     size_t input_index) {
203*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
204*523fa7a6SAndroid Build Coastguard Worker   methods_.at(method_name).inputs.at(input_index) = input_value;
205*523fa7a6SAndroid Build Coastguard Worker   return runtime::Error::Ok;
206*523fa7a6SAndroid Build Coastguard Worker }
207*523fa7a6SAndroid Build Coastguard Worker 
set_inputs(const std::string & method_name,const std::vector<runtime::EValue> & input_values)208*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::set_inputs(
209*523fa7a6SAndroid Build Coastguard Worker     const std::string& method_name,
210*523fa7a6SAndroid Build Coastguard Worker     const std::vector<runtime::EValue>& input_values) {
211*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
212*523fa7a6SAndroid Build Coastguard Worker   auto& inputs = methods_.at(method_name).inputs;
213*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
214*523fa7a6SAndroid Build Coastguard Worker       inputs.size() == input_values.size(),
215*523fa7a6SAndroid Build Coastguard Worker       InvalidArgument,
216*523fa7a6SAndroid Build Coastguard Worker       "input size: %zu does not match method input size: %zu",
217*523fa7a6SAndroid Build Coastguard Worker       input_values.size(),
218*523fa7a6SAndroid Build Coastguard Worker       inputs.size());
219*523fa7a6SAndroid Build Coastguard Worker   inputs = input_values;
220*523fa7a6SAndroid Build Coastguard Worker   return runtime::Error::Ok;
221*523fa7a6SAndroid Build Coastguard Worker }
222*523fa7a6SAndroid Build Coastguard Worker 
set_output(const std::string & method_name,runtime::EValue output_value,size_t output_index)223*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::set_output(
224*523fa7a6SAndroid Build Coastguard Worker     const std::string& method_name,
225*523fa7a6SAndroid Build Coastguard Worker     runtime::EValue output_value,
226*523fa7a6SAndroid Build Coastguard Worker     size_t output_index) {
227*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
228*523fa7a6SAndroid Build Coastguard Worker   auto& method = methods_.at(method_name).method;
229*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
230*523fa7a6SAndroid Build Coastguard Worker       output_value.isTensor(),
231*523fa7a6SAndroid Build Coastguard Worker       InvalidArgument,
232*523fa7a6SAndroid Build Coastguard Worker       "output type: %zu is not tensor",
233*523fa7a6SAndroid Build Coastguard Worker       (size_t)output_value.tag);
234*523fa7a6SAndroid Build Coastguard Worker   const auto& output_tensor = output_value.toTensor();
235*523fa7a6SAndroid Build Coastguard Worker   return method->set_output_data_ptr(
236*523fa7a6SAndroid Build Coastguard Worker       output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
237*523fa7a6SAndroid Build Coastguard Worker }
238*523fa7a6SAndroid Build Coastguard Worker 
239*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
240*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
241