1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/module/module.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/file_data_loader.h>
12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/data_loader/mmap_data_loader.h>
13*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/runtime.h>
15*523fa7a6SAndroid Build Coastguard Worker
/**
 * Unwraps a Result and re-homes its payload on the heap.
 *
 * The expression `result__` is evaluated exactly once. If the Result reports
 * failure, the enclosing function returns that error immediately. Otherwise
 * the contained object (the object itself, not a pointer to it) is moved into
 * a newly allocated instance and the macro evaluates to a std::unique_ptr
 * that owns it.
 *
 * Note: The failure path is a plain `return`, so any function that uses
 * ET_UNWRAP_UNIQUE must itself return a Result or an Error.
 *
 * @param[in] result__ Expression yielding the result to unwrap.
 */
#define ET_UNWRAP_UNIQUE(result__)                                \
  ({                                                              \
    auto et_unwrapped__ = (result__);                             \
    if (!et_unwrapped__.ok()) {                                   \
      return et_unwrapped__.error();                              \
    }                                                             \
    using et_value__ =                                            \
        std::remove_reference_t<decltype(*et_unwrapped__)>;       \
    std::make_unique<et_value__>(std::move(*et_unwrapped__));     \
  })
35*523fa7a6SAndroid Build Coastguard Worker
36*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
37*523fa7a6SAndroid Build Coastguard Worker namespace extension {
38*523fa7a6SAndroid Build Coastguard Worker
Module(const std::string & file_path,const LoadMode load_mode,std::unique_ptr<runtime::EventTracer> event_tracer)39*523fa7a6SAndroid Build Coastguard Worker Module::Module(
40*523fa7a6SAndroid Build Coastguard Worker const std::string& file_path,
41*523fa7a6SAndroid Build Coastguard Worker const LoadMode load_mode,
42*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::EventTracer> event_tracer)
43*523fa7a6SAndroid Build Coastguard Worker : file_path_(file_path),
44*523fa7a6SAndroid Build Coastguard Worker load_mode_(load_mode),
45*523fa7a6SAndroid Build Coastguard Worker memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
46*523fa7a6SAndroid Build Coastguard Worker temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47*523fa7a6SAndroid Build Coastguard Worker event_tracer_(std::move(event_tracer)) {
48*523fa7a6SAndroid Build Coastguard Worker runtime::runtime_init();
49*523fa7a6SAndroid Build Coastguard Worker }
50*523fa7a6SAndroid Build Coastguard Worker
Module(std::unique_ptr<runtime::DataLoader> data_loader,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)51*523fa7a6SAndroid Build Coastguard Worker Module::Module(
52*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::DataLoader> data_loader,
53*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
54*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::EventTracer> event_tracer)
56*523fa7a6SAndroid Build Coastguard Worker : data_loader_(std::move(data_loader)),
57*523fa7a6SAndroid Build Coastguard Worker memory_allocator_(
58*523fa7a6SAndroid Build Coastguard Worker memory_allocator ? std::move(memory_allocator)
59*523fa7a6SAndroid Build Coastguard Worker : std::make_unique<MallocMemoryAllocator>()),
60*523fa7a6SAndroid Build Coastguard Worker temp_allocator_(
61*523fa7a6SAndroid Build Coastguard Worker temp_allocator ? std::move(temp_allocator)
62*523fa7a6SAndroid Build Coastguard Worker : std::make_unique<MallocMemoryAllocator>()),
63*523fa7a6SAndroid Build Coastguard Worker event_tracer_(std::move(event_tracer)) {
64*523fa7a6SAndroid Build Coastguard Worker runtime::runtime_init();
65*523fa7a6SAndroid Build Coastguard Worker }
66*523fa7a6SAndroid Build Coastguard Worker
Module(std::shared_ptr<runtime::Program> program,std::unique_ptr<runtime::MemoryAllocator> memory_allocator,std::unique_ptr<runtime::MemoryAllocator> temp_allocator,std::unique_ptr<runtime::EventTracer> event_tracer)67*523fa7a6SAndroid Build Coastguard Worker Module::Module(
68*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<runtime::Program> program,
69*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::EventTracer> event_tracer)
72*523fa7a6SAndroid Build Coastguard Worker : program_(std::move(program)),
73*523fa7a6SAndroid Build Coastguard Worker memory_allocator_(
74*523fa7a6SAndroid Build Coastguard Worker memory_allocator ? std::move(memory_allocator)
75*523fa7a6SAndroid Build Coastguard Worker : std::make_unique<MallocMemoryAllocator>()),
76*523fa7a6SAndroid Build Coastguard Worker temp_allocator_(
77*523fa7a6SAndroid Build Coastguard Worker temp_allocator ? std::move(temp_allocator)
78*523fa7a6SAndroid Build Coastguard Worker : std::make_unique<MallocMemoryAllocator>()),
79*523fa7a6SAndroid Build Coastguard Worker event_tracer_(std::move(event_tracer)) {
80*523fa7a6SAndroid Build Coastguard Worker runtime::runtime_init();
81*523fa7a6SAndroid Build Coastguard Worker }
82*523fa7a6SAndroid Build Coastguard Worker
load(const runtime::Program::Verification verification)83*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::load(const runtime::Program::Verification verification) {
84*523fa7a6SAndroid Build Coastguard Worker if (!is_loaded()) {
85*523fa7a6SAndroid Build Coastguard Worker if (!data_loader_) {
86*523fa7a6SAndroid Build Coastguard Worker switch (load_mode_) {
87*523fa7a6SAndroid Build Coastguard Worker case LoadMode::File:
88*523fa7a6SAndroid Build Coastguard Worker data_loader_ =
89*523fa7a6SAndroid Build Coastguard Worker ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path_.c_str()));
90*523fa7a6SAndroid Build Coastguard Worker break;
91*523fa7a6SAndroid Build Coastguard Worker case LoadMode::Mmap:
92*523fa7a6SAndroid Build Coastguard Worker data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
93*523fa7a6SAndroid Build Coastguard Worker file_path_.c_str(), MmapDataLoader::MlockConfig::NoMlock));
94*523fa7a6SAndroid Build Coastguard Worker break;
95*523fa7a6SAndroid Build Coastguard Worker case LoadMode::MmapUseMlock:
96*523fa7a6SAndroid Build Coastguard Worker data_loader_ =
97*523fa7a6SAndroid Build Coastguard Worker ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path_.c_str()));
98*523fa7a6SAndroid Build Coastguard Worker break;
99*523fa7a6SAndroid Build Coastguard Worker case LoadMode::MmapUseMlockIgnoreErrors:
100*523fa7a6SAndroid Build Coastguard Worker data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
101*523fa7a6SAndroid Build Coastguard Worker file_path_.c_str(),
102*523fa7a6SAndroid Build Coastguard Worker MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103*523fa7a6SAndroid Build Coastguard Worker break;
104*523fa7a6SAndroid Build Coastguard Worker }
105*523fa7a6SAndroid Build Coastguard Worker };
106*523fa7a6SAndroid Build Coastguard Worker auto program = ET_UNWRAP_UNIQUE(
107*523fa7a6SAndroid Build Coastguard Worker runtime::Program::load(data_loader_.get(), verification));
108*523fa7a6SAndroid Build Coastguard Worker program_ = std::shared_ptr<runtime::Program>(
109*523fa7a6SAndroid Build Coastguard Worker program.release(), [](runtime::Program* pointer) { delete pointer; });
110*523fa7a6SAndroid Build Coastguard Worker }
111*523fa7a6SAndroid Build Coastguard Worker return runtime::Error::Ok;
112*523fa7a6SAndroid Build Coastguard Worker }
113*523fa7a6SAndroid Build Coastguard Worker
method_names()114*523fa7a6SAndroid Build Coastguard Worker runtime::Result<std::unordered_set<std::string>> Module::method_names() {
115*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(load());
116*523fa7a6SAndroid Build Coastguard Worker const auto method_count = program_->num_methods();
117*523fa7a6SAndroid Build Coastguard Worker std::unordered_set<std::string> result;
118*523fa7a6SAndroid Build Coastguard Worker result.reserve(method_count);
119*523fa7a6SAndroid Build Coastguard Worker
120*523fa7a6SAndroid Build Coastguard Worker for (auto index = 0; index < method_count; ++index) {
121*523fa7a6SAndroid Build Coastguard Worker result.emplace(program_->get_method_name(index).get());
122*523fa7a6SAndroid Build Coastguard Worker }
123*523fa7a6SAndroid Build Coastguard Worker return result;
124*523fa7a6SAndroid Build Coastguard Worker }
125*523fa7a6SAndroid Build Coastguard Worker
load_method(const std::string & method_name,torch::executor::EventTracer * event_tracer)126*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::load_method(
127*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name,
128*523fa7a6SAndroid Build Coastguard Worker torch::executor::EventTracer* event_tracer) {
129*523fa7a6SAndroid Build Coastguard Worker if (!is_method_loaded(method_name)) {
130*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(load());
131*523fa7a6SAndroid Build Coastguard Worker
132*523fa7a6SAndroid Build Coastguard Worker MethodHolder method_holder;
133*523fa7a6SAndroid Build Coastguard Worker const auto method_metadata =
134*523fa7a6SAndroid Build Coastguard Worker ET_UNWRAP(program_->method_meta(method_name.c_str()));
135*523fa7a6SAndroid Build Coastguard Worker const auto planned_buffersCount =
136*523fa7a6SAndroid Build Coastguard Worker method_metadata.num_memory_planned_buffers();
137*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_buffers.reserve(planned_buffersCount);
138*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_spans.reserve(planned_buffersCount);
139*523fa7a6SAndroid Build Coastguard Worker
140*523fa7a6SAndroid Build Coastguard Worker for (auto index = 0; index < planned_buffersCount; ++index) {
141*523fa7a6SAndroid Build Coastguard Worker const auto buffer_size =
142*523fa7a6SAndroid Build Coastguard Worker method_metadata.memory_planned_buffer_size(index).get();
143*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_buffers.emplace_back(buffer_size);
144*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_spans.emplace_back(
145*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_buffers.back().data(), buffer_size);
146*523fa7a6SAndroid Build Coastguard Worker }
147*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_memory =
148*523fa7a6SAndroid Build Coastguard Worker std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
149*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_spans.data(),
150*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_spans.size()));
151*523fa7a6SAndroid Build Coastguard Worker method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
152*523fa7a6SAndroid Build Coastguard Worker memory_allocator_.get(),
153*523fa7a6SAndroid Build Coastguard Worker method_holder.planned_memory.get(),
154*523fa7a6SAndroid Build Coastguard Worker temp_allocator_.get());
155*523fa7a6SAndroid Build Coastguard Worker method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156*523fa7a6SAndroid Build Coastguard Worker method_name.c_str(),
157*523fa7a6SAndroid Build Coastguard Worker method_holder.memory_manager.get(),
158*523fa7a6SAndroid Build Coastguard Worker event_tracer ? event_tracer : this->event_tracer()));
159*523fa7a6SAndroid Build Coastguard Worker method_holder.inputs.resize(method_holder.method->inputs_size());
160*523fa7a6SAndroid Build Coastguard Worker methods_.emplace(method_name, std::move(method_holder));
161*523fa7a6SAndroid Build Coastguard Worker }
162*523fa7a6SAndroid Build Coastguard Worker return runtime::Error::Ok;
163*523fa7a6SAndroid Build Coastguard Worker }
164*523fa7a6SAndroid Build Coastguard Worker
method_meta(const std::string & method_name)165*523fa7a6SAndroid Build Coastguard Worker runtime::Result<runtime::MethodMeta> Module::method_meta(
166*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name) {
167*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
168*523fa7a6SAndroid Build Coastguard Worker return methods_.at(method_name).method->method_meta();
169*523fa7a6SAndroid Build Coastguard Worker }
170*523fa7a6SAndroid Build Coastguard Worker
execute(const std::string & method_name,const std::vector<runtime::EValue> & input_values)171*523fa7a6SAndroid Build Coastguard Worker runtime::Result<std::vector<runtime::EValue>> Module::execute(
172*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name,
173*523fa7a6SAndroid Build Coastguard Worker const std::vector<runtime::EValue>& input_values) {
174*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
175*523fa7a6SAndroid Build Coastguard Worker auto& method = methods_.at(method_name).method;
176*523fa7a6SAndroid Build Coastguard Worker auto& inputs = methods_.at(method_name).inputs;
177*523fa7a6SAndroid Build Coastguard Worker
178*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < input_values.size(); ++i) {
179*523fa7a6SAndroid Build Coastguard Worker if (!input_values[i].isNone()) {
180*523fa7a6SAndroid Build Coastguard Worker inputs[i] = input_values[i];
181*523fa7a6SAndroid Build Coastguard Worker }
182*523fa7a6SAndroid Build Coastguard Worker }
183*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < inputs.size(); ++i) {
184*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
185*523fa7a6SAndroid Build Coastguard Worker !inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
186*523fa7a6SAndroid Build Coastguard Worker }
187*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
188*523fa7a6SAndroid Build Coastguard Worker exec_aten::ArrayRef<runtime::EValue>(inputs.data(), inputs.size())));
189*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
190*523fa7a6SAndroid Build Coastguard Worker
191*523fa7a6SAndroid Build Coastguard Worker const auto outputs_size = method->outputs_size();
192*523fa7a6SAndroid Build Coastguard Worker std::vector<runtime::EValue> outputs(outputs_size);
193*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(
194*523fa7a6SAndroid Build Coastguard Worker method->get_outputs(outputs.data(), outputs_size));
195*523fa7a6SAndroid Build Coastguard Worker
196*523fa7a6SAndroid Build Coastguard Worker return outputs;
197*523fa7a6SAndroid Build Coastguard Worker }
198*523fa7a6SAndroid Build Coastguard Worker
set_input(const std::string & method_name,const runtime::EValue & input_value,size_t input_index)199*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::set_input(
200*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name,
201*523fa7a6SAndroid Build Coastguard Worker const runtime::EValue& input_value,
202*523fa7a6SAndroid Build Coastguard Worker size_t input_index) {
203*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
204*523fa7a6SAndroid Build Coastguard Worker methods_.at(method_name).inputs.at(input_index) = input_value;
205*523fa7a6SAndroid Build Coastguard Worker return runtime::Error::Ok;
206*523fa7a6SAndroid Build Coastguard Worker }
207*523fa7a6SAndroid Build Coastguard Worker
set_inputs(const std::string & method_name,const std::vector<runtime::EValue> & input_values)208*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::set_inputs(
209*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name,
210*523fa7a6SAndroid Build Coastguard Worker const std::vector<runtime::EValue>& input_values) {
211*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
212*523fa7a6SAndroid Build Coastguard Worker auto& inputs = methods_.at(method_name).inputs;
213*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
214*523fa7a6SAndroid Build Coastguard Worker inputs.size() == input_values.size(),
215*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
216*523fa7a6SAndroid Build Coastguard Worker "input size: %zu does not match method input size: %zu",
217*523fa7a6SAndroid Build Coastguard Worker input_values.size(),
218*523fa7a6SAndroid Build Coastguard Worker inputs.size());
219*523fa7a6SAndroid Build Coastguard Worker inputs = input_values;
220*523fa7a6SAndroid Build Coastguard Worker return runtime::Error::Ok;
221*523fa7a6SAndroid Build Coastguard Worker }
222*523fa7a6SAndroid Build Coastguard Worker
set_output(const std::string & method_name,runtime::EValue output_value,size_t output_index)223*523fa7a6SAndroid Build Coastguard Worker runtime::Error Module::set_output(
224*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name,
225*523fa7a6SAndroid Build Coastguard Worker runtime::EValue output_value,
226*523fa7a6SAndroid Build Coastguard Worker size_t output_index) {
227*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
228*523fa7a6SAndroid Build Coastguard Worker auto& method = methods_.at(method_name).method;
229*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
230*523fa7a6SAndroid Build Coastguard Worker output_value.isTensor(),
231*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
232*523fa7a6SAndroid Build Coastguard Worker "output type: %zu is not tensor",
233*523fa7a6SAndroid Build Coastguard Worker (size_t)output_value.tag);
234*523fa7a6SAndroid Build Coastguard Worker const auto& output_tensor = output_value.toTensor();
235*523fa7a6SAndroid Build Coastguard Worker return method->set_output_data_ptr(
236*523fa7a6SAndroid Build Coastguard Worker output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
237*523fa7a6SAndroid Build Coastguard Worker }
238*523fa7a6SAndroid Build Coastguard Worker
239*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
240*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
241