#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>

#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#define CONVERT_EXCEPTION_TO_ERROR_CODE(...)                 \
  try {                                                      \
    __VA_ARGS__                                              \
  } catch (const std::exception& e) {                        \
    std::cerr << "Error: " << e.what() << std::endl;         \
    return AOTI_RUNTIME_FAILURE;                             \
  } catch (...) {                                            \
    std::cerr << "Unknown exception occurred." << std::endl; \
    return AOTI_RUNTIME_FAILURE;                             \
  }                                                          \
  return AOTI_RUNTIME_SUCCESS;
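
// Usage sketch (illustrative only; `ExampleShim` and
// `do_something_that_may_throw` are hypothetical): each extern "C" entry
// point below wraps its body in CONVERT_EXCEPTION_TO_ERROR_CODE so that
// C++ exceptions never unwind across the C ABI boundary. The macro itself
// supplies the `return AOTI_RUNTIME_SUCCESS;` on the non-throwing path:
//
//   AOTIRuntimeError ExampleShim(AOTInductorModelHandle handle) {
//     CONVERT_EXCEPTION_TO_ERROR_CODE({
//       do_something_that_may_throw(handle);
//     })
//   }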

#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name)  \
  do {                                                            \
    AOTI_RUNTIME_CHECK(                                           \
        actual_size == expected_size,                             \
        "expected " + std::string(name) + " vector size to be " + \
            std::to_string(expected_size) + ", but got " +        \
            std::to_string(actual_size));                         \
  } while (0)

// AOTInductor uses at::addmm_out, which doesn't support
// arguments that require gradient. For this reason, we
// enforce a no_grad context for the run APIs.
//
// A RAII, thread local (!) guard that enables or disables grad mode upon
// construction, and sets it back to the original value upon destruction.
struct AOTINoGradGuard {
  AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) {
    aoti_torch_grad_mode_set_enabled(false);
  }
  ~AOTINoGradGuard() {
    aoti_torch_grad_mode_set_enabled(prev_mode);
  }
  bool prev_mode;
};
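
// Usage sketch (illustrative only; `run_model_under_no_grad` is a
// hypothetical call): because the guard is RAII, the previous grad mode is
// restored even when the guarded code throws, since the destructor runs
// during stack unwinding:
//
//   {
//     AOTINoGradGuard guard;       // grad mode disabled for this scope
//     run_model_under_no_grad();   // may throw
//   }                              // prior grad mode restored here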

extern "C" {

AOTIRuntimeError AOTInductorModelContainerCreate(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    bool is_cpu,
    const char* cubin_dir) {
  return AOTInductorModelContainerCreateWithDevice(
      container_handle,
      num_models,
      is_cpu ? "cpu" : "cuda",
      cubin_dir);
}

AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    const char* device_str,
    const char* cubin_dir) {
  if (num_models == 0) {
    std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
    return AOTI_RUNTIME_FAILURE;
  }
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::optional<std::string> cubin_dir_opt;
    if (cubin_dir != nullptr) {
      cubin_dir_opt.emplace(cubin_dir);
    }
    auto* container = new torch::aot_inductor::AOTInductorModelContainer(
        num_models, std::string(device_str), cubin_dir_opt);
    *container_handle =
        reinterpret_cast<AOTInductorModelContainerHandle>(container);
  })
}

AOTIRuntimeError AOTInductorModelContainerDelete(
    AOTInductorModelContainerHandle container_handle) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* container =
        reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
            container_handle);
    delete container;
  });
}

AOTIRuntimeError AOTInductorModelContainerRun(
    AOTInductorModelContainerHandle container_handle,
    AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
                                     // are stolen; the array itself is borrowed
    size_t num_inputs,
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    size_t num_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
  AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");

  auto stream =
      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    container->run(
        input_handles, output_handles, stream, proxy_executor_handle);
  })
}
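
// End-to-end usage sketch (illustrative only; error handling elided, and
// `stream` and `proxy_executor` are hypothetical caller-side values). Per
// the ownership comments above, the caller gives up the input handles and
// takes ownership of the freshly written output handles:
//
//   AOTInductorModelContainerHandle container = nullptr;
//   AOTInductorModelContainerCreateWithDevice(
//       &container, /*num_models=*/1, "cuda", /*cubin_dir=*/nullptr);
//   AtenTensorHandle ins[1] = {...};  // handles stolen by the run
//   AtenTensorHandle outs[1] = {};    // filled by the run; caller owns
//   AOTInductorModelContainerRun(
//       container, ins, 1, outs, 1, stream, proxy_executor);
//   AOTInductorModelContainerDelete(container);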

AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_constants) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *num_constants = container->num_constants(); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantName(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    const char** name) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *name = container->constant_name(idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    const char** original_fqn) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *original_fqn = container->constant_original_fqn(idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    bool* from_folded) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *from_folded = container->constant_from_folded(idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
    AOTInductorModelContainerHandle container_handle,
    size_t idx,
    int32_t* dtype) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *dtype = container->constant_dtype(idx); })
}

AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle,
    bool use_inactive,
    bool validate_full_update) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto* input_map =
      reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
          constant_map_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->update_constant_buffer(
        *input_map, use_inactive, validate_full_update);
  })
}

AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
    AOTInductorModelContainerHandle container_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  return AOTInductorModelContainerUpdateConstantBuffer(
      container_handle,
      constant_map_handle,
      /*use_inactive*/ true,
      /*validate_full_update*/ true);
}

AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
    AOTInductorModelContainerHandle container_handle,
    bool use_inactive,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  auto stream =
      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    container->run_const_fold(use_inactive, stream, proxy_executor_handle);
  })
}

AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
    AOTInductorModelContainerHandle container_handle) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    container->swap_constant_buffer();
  })
}
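
// Typical weight-update sequence (illustrative sketch; `new_weights` is a
// hypothetical AOTInductorConstantMapHandle owned by the caller). The
// inactive buffer is populated and const-folded first, then swapped in:
//
//   AOTInductorModelContainerUpdateInactiveConstantBuffer(
//       container, new_weights);
//   AOTInductorModelContainerRunConstantFolding(
//       container, /*use_inactive=*/true, stream, proxy_executor);
//   AOTInductorModelContainerSwapConstantBuffer(container);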

AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
    AOTInductorModelContainerHandle container_handle,
    size_t* ret_num_inputs) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_num_inputs = container->num_inputs(); })
}

AOTIRuntimeError AOTInductorModelContainerGetInputName(
    AOTInductorModelContainerHandle container_handle,
    size_t input_idx,
    const char** ret_input_names) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_input_names = container->input_name(input_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
    AOTInductorModelContainerHandle container_handle,
    size_t* ret_num_outputs) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_num_outputs = container->num_outputs(); })
}

AOTIRuntimeError AOTInductorModelContainerGetOutputName(
    AOTInductorModelContainerHandle container_handle,
    size_t output_idx,
    const char** ret_output_names) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_output_names = container->output_name(output_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
    AOTInductorModelContainerHandle container_handle,
    const char** in_spec,
    const char** out_spec) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    *in_spec = container->get_in_spec();
    *out_spec = container->get_out_spec();
  })
}

AOTIRuntimeError AOTInductorModelCreate(
    AOTInductorModelHandle* model_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
    auto constant_array =
        std::make_shared<std::vector<torch::aot_inductor::ConstantHandle>>();
    auto* input_map =
        reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
            constant_map_handle);

    auto* model = new torch::aot_inductor::AOTInductorModel(
        constant_map,
        constant_array,
        "cpu", // device_str is hardcoded, as AOTInductorModelCreate is
               // only used for CPU models
        "");

    if (input_map) {
      for (auto const& kv : *input_map) {
        constant_map->emplace(kv.first, kv.second);
      }
    } else {
      model->load_constants();
    }

    *model_handle = reinterpret_cast<AOTInductorModelHandle>(model);
  })
}

AOTIRuntimeError AOTInductorModelRun(
    AOTInductorModelHandle model_handle,
    AtenTensorHandle* input_handles,
    AtenTensorHandle* output_handles) {
  auto* model =
      reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    AOTINoGradGuard guard;
    model->run_impl(
        input_handles,
        output_handles,
        (torch::aot_inductor::DeviceStreamType) nullptr,
        nullptr);
  })
}

AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* model =
        reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
    delete model;
  })
}
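
// Container-less usage sketch for the CPU-only model API (illustrative;
// `ins`/`outs` are hypothetical caller arrays, error handling elided).
// Passing a null constant map makes the model load its own constants:
//
//   AOTInductorModelHandle model = nullptr;
//   AOTInductorModelCreate(&model, /*constant_map_handle=*/nullptr);
//   AOTInductorModelRun(model, ins, outs);
//   AOTInductorModelDelete(model);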

AOTIRuntimeError AOTInductorModelGetNumOutputs(
    AOTInductorModelHandle model_handle,
    size_t* ret_num_outputs) {
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* model =
        reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
    *ret_num_outputs = model->num_outputs();
  })
}

AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
    AOTInductorModelHandle model_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
  auto* model =
      reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto constant_map = std::make_shared<torch::aot_inductor::ConstantMap>();
    auto* input_map =
        reinterpret_cast<std::unordered_map<std::string, AtenTensorHandle>*>(
            constant_map_handle);

    for (auto const& kv : *input_map) {
      constant_map->emplace(kv.first, kv.second);
    }
    model->update_constants_map(std::move(constant_map));
  })
}

} // extern "C"