// external/pytorch/torch/_inductor/codegen/aoti_runtime/interface.cpp
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>

#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#define CONVERT_EXCEPTION_TO_ERROR_CODE(...)                 \
  try {                                                      \
    __VA_ARGS__                                              \
  } catch (const std::exception& e) {                        \
    std::cerr << "Error: " << e.what() << std::endl;         \
    return AOTI_RUNTIME_FAILURE;                             \
  } catch (...) {                                            \
    std::cerr << "Unknown exception occurred." << std::endl; \
    return AOTI_RUNTIME_FAILURE;                             \
  }                                                          \
  return AOTI_RUNTIME_SUCCESS;
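
// Usage sketch: the macro wraps a function body so that any C++ exception is
// converted into a C-ABI-friendly status code instead of unwinding across the
// extern "C" boundary. A hypothetical example (Foo, FooHandle, and FooDoWork
// are illustrative names, not part of this file):
//
//   AOTIRuntimeError FooDoWork(FooHandle foo_handle) {
//     CONVERT_EXCEPTION_TO_ERROR_CODE({
//       reinterpret_cast<Foo*>(foo_handle)->do_work(); // may throw
//     })
//     // expands to: try { ... } catch (...) { return AOTI_RUNTIME_FAILURE; }
//     //             return AOTI_RUNTIME_SUCCESS;
//   }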

#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name)  \
  do {                                                            \
    AOTI_RUNTIME_CHECK(                                           \
        actual_size == expected_size,                             \
        "expected " + std::string(name) + " vector size to be " + \
            std::to_string(expected_size) + ", but got " +        \
            std::to_string(actual_size));                         \
  } while (0)
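
// For example, AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(),
// "inputs") produces a message like "expected inputs vector size to be 3, but
// got 2" (sizes illustrative). The do { ... } while (0) wrapper makes the
// macro behave as a single statement, so it stays safe inside an unbraced
// if/else.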

// AOTInductor uses at::addmm_out, which doesn't support
// arguments that require gradient. For this reason, we
// enforce a no_grad context for the run APIs.
//
// A RAII, thread-local (!) guard that enables or disables grad mode upon
// construction, and sets it back to the original value upon destruction.
struct AOTINoGradGuard {
  AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) {
    aoti_torch_grad_mode_set_enabled(false);
  }
  ~AOTINoGradGuard() {
    aoti_torch_grad_mode_set_enabled(prev_mode);
  }
  bool prev_mode;
};
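
// Usage sketch (illustrative): constructing the guard disables grad mode for
// the current thread; the destructor restores the previous value even if the
// guarded code throws.
//
//   {
//     AOTINoGradGuard guard;  // grad mode off from here on
//     container->run(...);    // runs under no_grad
//   }                         // previous grad mode restored here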

extern "C" {

AOTIRuntimeError AOTInductorModelContainerCreate(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    bool is_cpu,
    const char* cubin_dir) {
  return AOTInductorModelContainerCreateWithDevice(
      container_handle,
      num_models,
      is_cpu ? "cpu" : "cuda",
      cubin_dir);
}

AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    const char* device_str,
    const char* cubin_dir) {
  if (num_models == 0) {
    std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
    return AOTI_RUNTIME_FAILURE;
  }
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::optional<std::string> cubin_dir_opt;
    if (cubin_dir != nullptr) {
      cubin_dir_opt.emplace(cubin_dir);
    }
    auto* container = new torch::aot_inductor::AOTInductorModelContainer(
        num_models, std::string(device_str), cubin_dir_opt);
    *container_handle =
        reinterpret_cast<AOTInductorModelContainerHandle>(container);
  })
}
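
// Lifecycle sketch (illustrative; error handling elided):
//
//   AOTInductorModelContainerHandle handle = nullptr;
//   if (AOTInductorModelContainerCreateWithDevice(
//           &handle, /*num_models=*/1, "cuda", /*cubin_dir=*/nullptr) !=
//       AOTI_RUNTIME_SUCCESS) {
//     // report failure
//   }
//   // ... run inference ...
//   AOTInductorModelContainerDelete(handle);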

AOTIRuntimeError AOTInductorModelContainerDelete(
    AOTInductorModelContainerHandle container_handle) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* container =
        reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
            container_handle);
    delete container;
  });
}

AOTIRuntimeError AOTInductorModelContainerRun(
    AOTInductorModelContainerHandle container_handle,
    AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
                                     // are stolen; the array itself is borrowed
    size_t num_inputs,
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    size_t num_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
  AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");

  auto stream =
      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    container->run(
        input_handles, output_handles, stream, proxy_executor_handle);
  })
}
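
// Calling-convention sketch (illustrative): the run call takes ownership of
// each input handle and hands the caller ownership of each output handle;
// only the arrays themselves are borrowed. `handle` and the tensor handles
// are assumptions for the example:
//
//   std::vector<AtenTensorHandle> inputs = {/* ... */};  // ownership moves in
//   std::vector<AtenTensorHandle> outputs(num_outputs, nullptr);
//   AOTInductorModelContainerRun(
//       handle, inputs.data(), inputs.size(),
//       outputs.data(), outputs.size(),
//       /*stream_handle=*/nullptr, /*proxy_executor_handle=*/nullptr);
//   // The caller now owns each handle in `outputs` and must release it
//   // (e.g. via aoti_torch_delete_tensor_object).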

AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_constants) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *num_constants = container->num_constants(); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantName(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    const char** name) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *name = container->constant_name(idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    const char** original_fqn) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *original_fqn = container->constant_original_fqn(idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    bool* from_folded) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *from_folded = container->constant_from_folded(idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    int32_t* dtype) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *dtype = container->constant_dtype(idx); })
}

AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle,
    bool use_inactive,
    bool validate_full_update) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto input_map =
      reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
          constant_map_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->update_constant_buffer(
        *input_map, use_inactive, validate_full_update);
  })
}

AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  return AOTInductorModelContainerUpdateConstantBuffer(
      container_handle,
      constant_map_handle,
      /*use_inactive=*/true,
      /*validate_full_update=*/true);
}

AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
    AOTInductorModelContainerHandle container_handle,
    bool use_inactive,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto stream =
      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    container->run_const_fold(use_inactive, stream, proxy_executor_handle);
  })
}

AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
    AOTInductorModelContainerHandle container_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->swap_constant_buffer();
  })
}
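
// Weight-update flow sketch (illustrative): the container keeps an active and
// an inactive constant buffer, so new weights can be staged while the active
// buffer continues to serve runs. A hypothetical sequence (`handle`,
// `map_handle`, and `stream` are assumptions for the example):
//
//   // 1. Stage new constants into the inactive buffer.
//   AOTInductorModelContainerUpdateInactiveConstantBuffer(handle, map_handle);
//   // 2. Re-run constant folding against the inactive buffer.
//   AOTInductorModelContainerRunConstantFolding(
//       handle, /*use_inactive=*/true, stream,
//       /*proxy_executor_handle=*/nullptr);
//   // 3. Make the inactive buffer the active one.
//   AOTInductorModelContainerSwapConstantBuffer(handle);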

AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
    AOTInductorModelContainerHandle container_handle,
    size_t* ret_num_inputs) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_num_inputs = container->num_inputs(); })
}

AOTIRuntimeError AOTInductorModelContainerGetInputName(
    AOTInductorModelContainerHandle container_handle,
    size_t input_idx,
    const char** ret_input_names) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_input_names = container->input_name(input_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
    AOTInductorModelContainerHandle container_handle,
    size_t* ret_num_outputs) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_num_outputs = container->num_outputs(); })
}

AOTIRuntimeError AOTInductorModelContainerGetOutputName(
    AOTInductorModelContainerHandle container_handle,
    size_t output_idx,
    const char** ret_output_names) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_output_names = container->output_name(output_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
    AOTInductorModelContainerHandle container_handle,
    const char** in_spec,
    const char** out_spec) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    *in_spec = container->get_in_spec();
    *out_spec = container->get_out_spec();
  })
}

AOTIRuntimeError AOTInductorModelCreate(
    AOTInductorModelHandle* model_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
    auto constant_array =
        std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
    auto input_map =
        reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
            constant_map_handle);

    auto model = new torch::aot_inductor::AOTInductorModel(
        constant_map,
        constant_array,
        // device_str is hardcoded, as AOTInductorModelCreate is only used for
        // CPU models
        "cpu",
        "");

    if (input_map) {
      for (auto const& kv : *input_map) {
        constant_map->emplace(kv.first, kv.second);
      }
    } else {
      model->load_constants();
    }

    *model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
  })
}
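
// Usage sketch (illustrative): create a standalone CPU model, run it, and
// delete it. Passing a null constant map makes the model load its own
// constants via load_constants(). The `inputs` and `outputs` arrays are
// assumptions for the example, with the ownership rules described above:
//
//   AOTInductorModelHandle model = nullptr;
//   AOTInductorModelCreate(&model, /*constant_map_handle=*/nullptr);
//   AOTInductorModelRun(model, inputs.data(), outputs.data());
//   AOTInductorModelDelete(model);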

AOTIRuntimeError AOTInductorModelRun(
    AOTInductorModelHandle model_handle,
    AtenTensorHandle* input_handles,
    AtenTensorHandle* output_handles) {
  auto model =
      reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    model->run_impl(
        input_handles,
        output_handles,
        (torch::aot_inductor::DeviceStreamType) nullptr,
        nullptr);
  })
}

AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto model =
        reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
    delete model;
  })
}

AOTIRuntimeError AOTInductorModelGetNumOutputs(
    AOTInductorModelHandle model_handle,
    size_t* ret_num_outputs) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto model =
        reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
    *ret_num_outputs = model->num_outputs();
  })
}

AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
    AOTInductorModelHandle model_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  auto model =
      reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
    auto input_map =
        reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
            constant_map_handle);

    for (auto const& kv : *input_map) {
      constant_map->emplace(kv.first, kv.second);
    }
    model->update_constants_map(std::move(constant_map));
  })
}

} // extern "C"
355