1 #include <pybind11/pytypes.h>
2 #include <torch/csrc/utils/pybind.h>
3 #include <torch/csrc/utils/python_arg_parser.h>
4 #include <torch/csrc/utils/schema_info.h>
5
6 #include <ATen/core/operator_name.h>
7 #include <torch/csrc/jit/api/module.h>
8 #include <torch/csrc/jit/backends/backend_init.h>
9 #include <torch/csrc/jit/codegen/cuda/interface.h>
10 // #include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
11 #include <torch/csrc/jit/codegen/fuser/interface.h>
12 #include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
13 #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
14 #include <torch/csrc/jit/codegen/onednn/interface.h>
15 #endif
16 #include <c10/core/SymNodeImpl.h>
17 #include <torch/csrc/jit/frontend/ir_emitter.h>
18 #include <torch/csrc/jit/frontend/tracer.h>
19 #include <torch/csrc/jit/ir/irparser.h>
20 #include <torch/csrc/jit/jit_log.h>
21 #include <torch/csrc/jit/passes/autocast.h>
22 #include <torch/csrc/jit/passes/batch_mm.h>
23 #include <torch/csrc/jit/passes/canonicalize.h>
24 #include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
25 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
26 #include <torch/csrc/jit/passes/constant_pooling.h>
27 #include <torch/csrc/jit/passes/constant_propagation.h>
28 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
29 #include <torch/csrc/jit/passes/create_functional_graphs.h>
30 #include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h>
31 #include <torch/csrc/jit/passes/dead_code_elimination.h>
32 #include <torch/csrc/jit/passes/decompose_ops.h>
33 #include <torch/csrc/jit/passes/device_type_analysis.h>
34 #include <torch/csrc/jit/passes/dtype_analysis.h>
35 #include <torch/csrc/jit/passes/erase_number_types.h>
36 #include <torch/csrc/jit/passes/fold_conv_bn.h>
37 #include <torch/csrc/jit/passes/freeze_module.h>
38 #include <torch/csrc/jit/passes/frozen_concat_linear.h>
39 #include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
40 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
41 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
42 #include <torch/csrc/jit/passes/frozen_linear_folding.h>
43 #include <torch/csrc/jit/passes/frozen_linear_transpose.h>
44 #include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
45 #include <torch/csrc/jit/passes/fuse_linear.h>
46 #include <torch/csrc/jit/passes/fuse_relu.h>
47 #include <torch/csrc/jit/passes/graph_fuser.h>
48 #include <torch/csrc/jit/passes/inline_fork_wait.h>
49 #include <torch/csrc/jit/passes/inliner.h>
50 #include <torch/csrc/jit/passes/integer_value_refinement.h>
51 #include <torch/csrc/jit/passes/loop_unrolling.h>
52 #include <torch/csrc/jit/passes/lower_graph.h>
53 #include <torch/csrc/jit/passes/lower_tuples.h>
54 #include <torch/csrc/jit/passes/metal_rewrite.h>
55 #include <torch/csrc/jit/passes/mobile_optimizer_type.h>
56 #include <torch/csrc/jit/passes/normalize_ops.h>
57 #include <torch/csrc/jit/passes/peephole.h>
58 #include <torch/csrc/jit/passes/peephole_list_idioms.h>
59 #include <torch/csrc/jit/passes/quantization/dedup_module_uses.h>
60 #include <torch/csrc/jit/passes/quantization/finalize.h>
61 #include <torch/csrc/jit/passes/quantization/fusion_passes.h>
62 #include <torch/csrc/jit/passes/quantization/insert_observers.h>
63 #include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
64 #include <torch/csrc/jit/passes/quantization/quantization_type.h>
65 #include <torch/csrc/jit/passes/refine_tuple_types.h>
66 #include <torch/csrc/jit/passes/remove_dropout.h>
67 #include <torch/csrc/jit/passes/remove_expands.h>
68 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
69 #include <torch/csrc/jit/passes/remove_mutation.h>
70 #include <torch/csrc/jit/passes/replacement_of_old_operators.h>
71 #include <torch/csrc/jit/passes/restore_mutation.h>
72 #include <torch/csrc/jit/passes/shape_analysis.h>
73 #include <torch/csrc/jit/passes/specialize_autogradzero.h>
74 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
75 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
76 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
77 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
78 #include <torch/csrc/jit/passes/vulkan_rewrite.h>
79 #include <torch/csrc/jit/passes/xnnpack_rewrite.h>
80 #include <torch/csrc/jit/python/pybind_utils.h>
81 #include <torch/csrc/jit/python/python_arg_flatten.h>
82 #include <torch/csrc/jit/python/python_custom_class.h>
83 #include <torch/csrc/jit/python/python_ir.h>
84 #include <torch/csrc/jit/python/python_tracer.h>
85 #include <torch/csrc/jit/python/python_tree_views.h>
86 #include <torch/csrc/jit/python/script_init.h>
87 #include <torch/csrc/jit/python/utf8_decoding_ignore.h>
88 #include <torch/csrc/jit/runtime/argument_spec.h>
89 #include <torch/csrc/jit/runtime/autodiff.h>
90 #include <torch/csrc/jit/runtime/decomposition_registry.h>
91 #include <torch/csrc/jit/runtime/graph_executor.h>
92 #include <torch/csrc/jit/runtime/jit_exception.h>
93 #include <torch/csrc/jit/runtime/jit_trace.h>
94 #include <torch/csrc/jit/runtime/operator.h>
95 #include <torch/csrc/jit/runtime/print_handler.h>
96 #include <torch/csrc/jit/runtime/static/init.h>
97 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
98 #include <torch/csrc/jit/serialization/export.h>
99 #include <torch/csrc/jit/serialization/import.h>
100 #include <torch/csrc/jit/tensorexpr/kernel.h>
101 #include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
102 #include <torch/csrc/utils/cpp_stacktraces.h>
103
104 #include <c10/macros/Export.h>
105 #include <c10/util/irange.h>
106 #include <c10/util/signal_handler.h>
107 #include <caffe2/serialize/inline_container.h>
108
109 #include <pybind11/cast.h>
110 #include <pybind11/functional.h>
111 #include <pybind11/iostream.h>
112 #include <pybind11/operators.h>
113
114 #include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
115 #include <memory>
116 #include <sstream>
117 #include <stdexcept>
118 #include <string>
119 #include <tuple>
120 #include <utility>
121
122 namespace torch::jit {
123
124 using c10::AliasInfo;
125 using c10::Argument;
126 using c10::FunctionSchema;
127 using c10::SchemaArgType;
128 using c10::SchemaArgument;
129 using c10::SymNode;
130 using caffe2::serialize::PyTorchStreamReader;
131 using caffe2::serialize::PyTorchStreamWriter;
132 using torch::utils::SchemaInfo;
133
134 namespace {
135
136 using autograd::variable_list;
137
loadPythonClasses()138 bool loadPythonClasses() {
139 // Leaving this code here, because it will likely be useful at some point
140 // PyObject *jit_module = PyImport_ImportModule("torch.jit");
141 // THPUtils_assert(jit_module, "class loader couldn't access "
142 //"torch.jit module");
143 // PyObject *jit_dict = PyModule_GetDict(jit_module);
144
145 return true;
146 }
147
opAllowsNumbersAsTensors(c10::Symbol symbol)148 static bool opAllowsNumbersAsTensors(c10::Symbol symbol) {
149 return symbol.is_prims() || symbol.is_nvprims() ||
150 (symbol.is_aten() &&
151 torch::should_allow_numbers_as_tensors(symbol.toUnqualString()));
152 }
153
toTypeInferredIValueOptional(py::handle input)154 std::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
155 // Errors need to be caught here because toTypeInferredIValue errors out
156 // on various object types, but we want it to work with all types.
157 try {
158 return toTypeInferredIValue(input);
159 } catch (const c10::Error& e) {
160 return std::nullopt;
161 }
162 }
163 } // anonymous namespace
164
165 #if !defined(USE_ROCM)
166 TORCH_API void runJITCPPTests();
167 #endif
168
initJITBindings(PyObject * module)169 void initJITBindings(PyObject* module) {
170 auto m = py::handle(module).cast<py::module>();
171 auto jit = m.def_submodule("_jit");
172
173 // This is a static object, so we must leak the Python object
174 // "release()" is used here to preserve 1 refcount on the
175 // object, preventing it from ever being de-allocated by CPython.
176 static py::handle exc =
177 py::exception<JITException>(m, "JITException").release();
178
179 py::register_exception_translator([](std::exception_ptr p) {
180 try {
181 if (p) {
182 std::rethrow_exception(p);
183 }
184 } catch (const JITException& e) {
185 // special handling of JITException, to set its python class name and msg
186 py::gil_scoped_acquire acquire;
187 const auto& className = e.getPythonClassName();
188 const auto& originalMsg = e.getOriginalMsg();
189 JITException::setCaughtOriginalMsg(originalMsg.value_or(""));
190 JITException::setCaughtPythonClassName(className.value_or(""));
191 // If we still had the py::exception<JITException> object, we could
192 // just call it. But we must get a handle to leak it and there is no
193 // way I can find to re-create it from the handle. So setting the
194 // exception manually
195 PyErr_SetString(exc.ptr(), e.what());
196 }
197 });
198
199 m.def(
200 "_get_caught_jit_exception_class_name",
201 JITException::getCaughtPythonClassName);
202 m.def(
203 "_get_caught_jit_exception_original_msg",
204 JITException::getCaughtOriginalMsg);
205
206 py::class_<python::IODescriptor> iodescriptor(
207 m,
208 "IODescriptor"); // NOLINT(bugprone-unused-raii)
209
210 m.def("_jit_init", loadPythonClasses)
211 .def(
212 "_jit_debug_fuser_num_cached_kernel_specs",
213 torch::jit::fuser::debugNumCachedKernelSpecs)
214 .def("_jit_pass_lower_all_tuples", LowerAllTuples)
215 .def(
216 "_new_symbolic_shape_symbol",
217 []() { return c10::ShapeSymbol::newSymbol().value(); })
218 .def(
219 "_jit_shape_compute_graph_for_node",
220 [](Node* n) -> std::optional<std::shared_ptr<Graph>> {
221 if (!n->maybeSchema()) {
222 return std::nullopt;
223 }
224 return shapeComputeGraphForSchema(n->schema());
225 })
226 .def(
227 "_jit_decomposition_graph_for_node",
228 [](Node* n) -> std::optional<std::shared_ptr<Graph>> {
229 if (!n->maybeSchema()) {
230 return std::nullopt;
231 }
232 return GetDecomposition(n->schema());
233 })
234 .def("_jit_pass_run_decompositions", RunDecompositions)
235 // using Node* here instead of Schema because looking up the schema
236 // and passing it in from Python will have a different pointer than the
237 // schema that is globally used for caching
238 .def(
239 "_jit_register_shape_compute_graph_for_node",
240 [](Node* n, std::shared_ptr<Graph>& graph) {
241 if (n->maybeSchema()) {
242 const FunctionSchema& schema = n->schema();
243 RegisterShapeComputeGraphForSchema(schema, graph);
244 } else {
245 TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
246 }
247 })
248 .def(
249 "_jit_register_decomposition_for_schema",
250 [](const FunctionSchema& s, std::shared_ptr<Graph>& graph) {
251 // because this is invoked by python, the function schema *
252 // becomes different, and we need to find and reuse the
253 // one that is used for caching
254 auto op =
255 findOperatorFor(c10::OperatorName(s.name(), s.overload_name()));
256 RegisterDecomposition(op->schema(), graph);
257 })
258 .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
259 .def(
260 "_jit_pass_propagate_shapes_on_graph_and_build_compute",
261 [](std::shared_ptr<Graph>& graph) {
262 return PropagateShapesAndBuildLargeShapeComputeGraph(
263 graph, *graph->nodes().begin(), *graph->nodes().end());
264 })
265 .def(
266 "_jit_pass_propagate_shapes_on_graph_and_build_compute",
267 [](std::shared_ptr<Graph>& graph, Node* beg) {
268 return PropagateShapesAndBuildLargeShapeComputeGraph(
269 graph, beg, *graph->nodes().end());
270 })
271 .def(
272 "_jit_pass_propagate_shapes_on_graph_and_build_compute",
273 PropagateShapesAndBuildLargeShapeComputeGraph)
274 .def("_jit_pass_integer_value_refinement", RefineIntegerValues)
275 .def(
276 "_jit_set_symbolic_shapes_test_mode",
277 &setSymbolicShapeAnalysisTestMode)
278 .def(
279 "_jit_symbolic_shapes_test_mode_enabled",
280 &symbolicShapeAnalysisTestModeEnabled)
281 .def("_jit_pass_autocast", Autocast)
282 .def("_jit_set_autocast_mode", &setAutocastMode)
283 .def("_jit_pass_fuse", FuseGraph)
284 .def(
285 "_jit_pass_replace_old_ops_with_upgraders",
286 [](std::shared_ptr<Graph>& g) {
287 return ReplaceOldOperatorsWithUpgraders(g);
288 })
289 .def(
290 "_jit_pass_dce",
291 [](std::shared_ptr<Graph>& g) {
292 return EliminateDeadCode(g->block()); // overload resolution
293 })
294 .def(
295 "_jit_pass_dce_allow_deleting_nodes_with_side_effects",
296 [](std::shared_ptr<Graph>& g) {
297 return EliminateDeadCode(
298 g->block(),
299 true,
300 DCESideEffectPolicy::
301 ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload
302 // resolution
303 })
304 .def(
305 "_jit_pass_cse",
306 [](std::shared_ptr<Graph>& g) {
307 return EliminateCommonSubexpression(g); // overload resolution
308 })
309 .def(
310 "_jit_pass_fuse_quantized_add_relu",
311 [](std::shared_ptr<Graph>& g) {
312 return FuseQuantizedAddRelu(g); // overload resolution
313 })
314 .def(
315 "_jit_pass_insert_observers",
316 [](Module& module,
317 const std::string& method_name,
318 const py::dict& qconfig_dict,
319 bool inplace,
320 int quant_type_int) {
321 auto dict = py::cast<std::unordered_map<
322 std::string,
323 std::optional<std::tuple<Module, Module>>>>(qconfig_dict);
324 auto quant_type = static_cast<QuantType>(quant_type_int);
325 return InsertObservers(
326 module, method_name, dict, inplace, quant_type);
327 },
328 py::arg("module"),
329 py::arg("method_name"),
330 py::arg("qconfig_dict"),
331 py::arg("inplace"),
332 py::arg("quant_type_int") = 1)
333 .def(
334 "_jit_pass_insert_observer_method_for_ondevice_ptq",
335 [](Module& module,
336 const std::string& method_name,
337 const py::dict& qconfig_dict,
338 bool inplace,
339 int quant_type_int) {
340 auto dict = py::cast<std::unordered_map<
341 std::string,
342 std::optional<std::tuple<Module, Module>>>>(qconfig_dict);
343 auto quant_type = static_cast<QuantType>(quant_type_int);
344 return InsertObserversForOnDevicePTQ(
345 module, method_name, dict, inplace, quant_type);
346 },
347 py::arg("module"),
348 py::arg("method_name"),
349 py::arg("qconfig_dict"),
350 py::arg("inplace"),
351 py::arg("quant_type_int") = 1)
352 .def(
353 "_jit_pass_insert_quant_dequant",
354 [](Module& module,
355 const std::string& method_name,
356 bool inplace,
357 bool debug,
358 int quant_type_int) {
359 auto quant_type = static_cast<QuantType>(quant_type_int);
360 return InsertQuantDeQuant(
361 module, method_name, inplace, debug, quant_type);
362 },
363 py::arg("module"),
364 py::arg("method_name"),
365 py::arg("inplace"),
366 py::arg("debug"),
367 py::arg("quant_type_int") = 1)
368 .def(
369 "_jit_pass_insert_quant_dequant_for_ondevice_ptq",
370 [](Module& module,
371 const std::string& method_name,
372 bool inplace,
373 bool debug,
374 int quant_type_int) {
375 auto quant_type = static_cast<QuantType>(quant_type_int);
376 return InsertQuantDeQuantOnDevicePTQ(
377 module, method_name, inplace, debug, quant_type);
378 },
379 py::arg("module"),
380 py::arg("method_name"),
381 py::arg("inplace"),
382 py::arg("debug"),
383 py::arg("quant_type_int") = 1)
384 .def(
385 "_jit_pass_insert_prepack_unpack",
386 [](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })
387 .def(
388 "_jit_pass_insert_prepack_unpack",
389 [](Module& module) { return InsertPrepackUnpack(module); })
390 .def(
391 "_jit_pass_quant_fusion",
392 [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
393 .def(
394 "_jit_pass_fold_convbn",
395 [](Module& module) { return FoldConvBatchNorm(module); })
396 .def(
397 "_jit_pass_dbr_quant_remove_redundant_aliases",
398 [](Module& module) { return DBRQuantRemoveRedundantAliases(module); })
399 .def(
400 "_freeze_module",
401 [](Module& module,
402 std::vector<std::string>& preservedAttrs,
403 bool freezeInterfaces,
404 bool preserveParameters) {
405 return freeze_module(
406 module, preservedAttrs, freezeInterfaces, preserveParameters);
407 },
408 py::arg("module"),
409 py::arg("preservedAttrs") = std::vector<std::string>(),
410 py::arg("freezeInterfaces") = true,
411 py::arg("preserveParameters") = false)
412 .def("_jit_pass_concat_frozen_linear", &FrozenConcatLinear)
413 .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
414 .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
415 .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
416 .def("_jit_pass_fold_frozen_linear_bn", &FoldFrozenLinearBatchnorm)
417 .def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN)
418 .def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu)
419 .def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose)
420 .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
421 .def(
422 "_jit_pass_optimize_for_inference",
423 [](Module& module, const std::vector<std::string>& other_methods) {
424 optimize_for_inference(module, other_methods);
425 },
426 py::arg("module"),
427 py::arg("other_methods") = std::vector<std::string>())
428 .def("_jit_pass_fuse_linear", &FuseLinear)
429 .def(
430 "_jit_pass_fuse_add_relu",
431 [](std::shared_ptr<Graph>& graph) { FuseAddRelu(graph); })
432 .def("_jit_pass_dedup_module_uses", &DedupModuleUses)
433 .def("_jit_pass_replicate_dequantize", &ReplicateDeQuant)
434 .def(
435 "_jit_pass_swap_functional_linear",
436 [](std::shared_ptr<Graph>& graph) { SwapFunctionalLinear(graph); })
437 .def(
438 "_jit_pass_swap_functional_linear",
439 [](Module& module) { SwapFunctionalLinear(module); })
440 .def(
441 "_jit_pass_quant_finalize",
442 [](Module& module,
443 int quant_type_int,
444 const std::vector<std::string>& preserved_attrs) {
445 auto quant_type = static_cast<QuantType>(quant_type_int);
446 return Finalize(module, quant_type, preserved_attrs);
447 },
448 py::arg("module"),
449 py::arg("quant_type_int") = 1,
450 py::arg("preserved_attrs") = std::vector<std::string>())
451 .def(
452 "_jit_pass_quant_finalize_for_ondevice_ptq",
453 [](Module& module,
454 int quant_type_int,
455 const std::string& method_name) {
456 auto quant_type = static_cast<QuantType>(quant_type_int);
457 return FinalizeOnDevicePTQ(module, quant_type, method_name);
458 },
459 py::arg("module"),
460 py::arg("quant_type_int") = 1,
461 py::arg("preserved_attrs") = std::vector<std::string>())
462 .def(
463 "_jit_pass_pattern_based_rewrite",
464 [](const Module& m) { return PatternBasedRewrite(m); })
465 .def(
466 "_jit_pass_custom_pattern_based_rewrite",
467 [](const std::string& pattern,
468 const std::string& fused_node_name,
469 const Module& m) {
470 SubgraphRewriter subgraph_rewriter;
471 subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name);
472 subgraph_rewriter.runOnModule(m);
473 })
474 .def(
475 "_jit_pass_custom_pattern_based_rewrite_graph",
476 [](const std::string& pattern,
477 const std::string& fused_node_name,
478 std::shared_ptr<Graph> g,
479 const std::vector<std::pair<std::string, std::string>>&
480 value_name_pairs) {
481 SubgraphRewriter subgraph_rewriter;
482 subgraph_rewriter.RegisterRewritePattern(
483 pattern, fused_node_name, value_name_pairs);
484 subgraph_rewriter.runOnGraph(g);
485 },
486 py::arg("pattern"),
487 py::arg("fused_node_name"),
488 py::arg("g"),
489 py::arg("value_name_pairs") =
490 std::vector<std::pair<std::string, std::string>>())
491 .def("_jit_pass_constant_pooling", ConstantPooling)
492 // RemoveInplaceOps is used by CoreML so it must be removed with care.
493 .def("_jit_pass_propagate_dtype", DtypePropagation)
494 .def("_jit_pass_propagate_device", DeviceTypePropagation)
495 .def(
496 "_jit_pass_remove_inplace_ops",
497 [](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
498 .def(
499 "_jit_pass_create_functional_graphs",
500 [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
501 .def(
502 "_jit_pass_remove_mutation",
503 [](std::shared_ptr<Graph>& g) {
504 RemoveListMutation(g);
505 return RemoveTensorMutation(g);
506 })
507 .def(
508 "_jit_pass_functional_to_inplace_activation",
509 [](std::shared_ptr<Graph>& g) {
510 return FunctionalToInplaceActivation(g);
511 })
512 .def(
513 "_jit_pass_inplace_to_functional_activation",
514 [](std::shared_ptr<Graph>& g) {
515 return InplaceToFunctionalActivation(g);
516 })
517 .def(
518 "_jit_pass_inline_functional_graphs",
519 [](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })
520 .def(
521 "_jit_pass_peephole",
522 [](const std::shared_ptr<Graph>& g, bool disable_shape_peepholes) {
523 return PeepholeOptimize(g, disable_shape_peepholes);
524 },
525 py::arg("graph"),
526 py::arg("disable_shape_peepholes") = false)
527 .def(
528 "_jit_pass_peephole_list_idioms",
529 [](const std::shared_ptr<Graph>& g, bool refine_list_len) {
530 return PeepholeOptimizeListIdioms(g, refine_list_len);
531 },
532 py::arg("graph"),
533 py::arg("refine_list_len") = false)
534 .def(
535 "_jit_pass_refine_integer_values",
536 [](std::shared_ptr<Graph>& g) { return RefineIntegerValues(g); })
537 .def(
538 "_jit_pass_fuse_addmm",
539 [](std::shared_ptr<Graph>& g) { return FuseAddMM(g); })
540 .def(
541 "_jit_pass_canonicalize",
542 [](const std::shared_ptr<Graph>& g, bool keep_unique_names = true) {
543 return Canonicalize(g, keep_unique_names);
544 },
545 py::arg("graph"),
546 py::arg("keep_unique_names") = true)
547 .def("_jit_pass_lint", LintGraph)
548 .def(
549 "_jit_pass_complete_shape_analysis",
550 [](const std::shared_ptr<Graph>& graph,
551 const py::tuple& inputs,
552 bool with_grad) {
553 ArgumentSpecCreator arg_spec_creator(*graph);
554 Stack stack;
555 stack.reserve(inputs.size()); // captures?
556 for (auto& obj : inputs) {
557 stack.push_back(toTypeInferredIValue(obj));
558 }
559 ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
560 arg_spec_creator.specializeTypes(*graph, spec);
561 // We only get partial specialization from the arg_spec_creator, but
562 // we want full shape specialization. The alternative would be to
563 // have a "complete type inference" function in ArguemntSpecCreator.
564 auto g_inputs = graph->inputs();
565 for (const auto i : c10::irange(inputs.size())) {
566 if (stack[i].isTensor()) {
567 g_inputs[i]->setType(stack[i].type());
568 }
569 }
570 PropagateInputShapes(graph);
571 })
572 .def(
573 "_jit_interpret_graph",
574 [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
575 Stack stack;
576 stack.reserve(inputs.size()); // captures?
577 for (auto& obj : inputs) {
578 stack.push_back(toTypeInferredIValue(obj));
579 }
580 auto g_inputs = graph->inputs();
581 for (const auto i : c10::irange(inputs.size())) {
582 if (stack[i].isTensor()) {
583 g_inputs[i]->setType(stack[i].type());
584 }
585 }
586 Code code(graph, "<on-demand-func>");
587 InterpreterState(code).run(stack);
588 return createPyObjectForStack(std::move(stack));
589 },
590 py::doc(
591 "Interpret a JIT graph with given inputs without running any optimization passes on it"))
592 .def(
593 "_jit_trace_graph",
594 [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
595 Stack stack;
596 stack.reserve(inputs.size()); // captures?
597 for (auto& obj : inputs) {
598 stack.push_back(toTypeInferredIValue(obj));
599 }
600 auto g_inputs = graph->inputs();
601 for (const auto i : c10::irange(inputs.size())) {
602 if (stack[i].isTensor()) {
603 g_inputs[i]->setType(stack[i].type());
604 }
605 }
606 return TraceGraph(graph, stack);
607 })
608 .def(
609 "_jit_trace_module",
610 [](Module& model, const py::tuple& inputs) {
611 auto graph = model.get_method("forward").graph();
612 Stack stack;
613 stack.reserve(inputs.size() + 1); // captures?
614 push(stack, model._ivalue());
615 for (auto& obj : inputs) {
616 stack.push_back(toTypeInferredIValue(obj));
617 }
618 auto traced = TraceGraph(graph, stack);
619 GRAPH_DUMP("Traced Graph", traced);
620
621 // the easiest way to replace a graph in a module is
622 // to remove all the nodes in the original graph
623 // clone everything from the traced one
624 graph->block()->clear();
625 graph->block()->cloneFrom(traced->block(), nullptr);
626 GRAPH_DUMP("Copied Graph", graph);
627 })
628 .def("_jit_pass_remove_expands", RemoveExpands)
629 .def("_jit_pass_erase_number_types", EraseNumberTypes)
630 .def("_jit_pass_inline_fork_wait", InlineForkWait)
631 .def("_jit_pass_inline", Inline)
632 .def(
633 "_jit_pass_lower_graph",
634 [](std::shared_ptr<Graph>& graph, const Module& self) {
635 return LowerGraph(*graph, self._ivalue());
636 })
637 .def("_jit_pass_loop_unrolling", UnrollLoops)
638 .def("_jit_pass_constant_loop_unrolling", UnrollConstantLoops)
639 .def(
640 "_jit_pass_constant_propagation_immutable_types",
641 [](std::shared_ptr<Graph>& g) {
642 return ConstantPropagationImmutableTypes(g);
643 })
644 .def(
645 "_jit_pass_constant_propagation",
646 [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); },
647 py::arg("graph"))
648 .def("_jit_pass_erase_shape_information", EraseShapeInformation)
649 .def(
650 "_jit_object_is_non_holding",
651 [](Node& n) {
652 return toIValue(n.output())->toObject()->is_weak_compilation_ref();
653 })
654 .def(
655 "_jit_erase_non_input_shape_information",
656 [](std::shared_ptr<Graph>& g) {
657 std::vector<TypePtr> input_types;
658 for (Value* v : g->inputs()) {
659 if (auto tt = v->type()->cast<TensorType>()) {
660 input_types.emplace_back(tt);
661 } else {
662 input_types.emplace_back(nullptr);
663 }
664 }
665 EraseShapeInformation(g);
666 for (size_t i = 0; i < input_types.size(); ++i) {
667 if (input_types[i]) {
668 g->inputs().at(i)->setType(input_types[i]);
669 }
670 }
671 })
672 .def(
673 "_jit_pass_create_autodiff_subgraphs",
674 [](const std::shared_ptr<Graph>& graph, const py::object& threshold) {
675 if (threshold.is_none()) {
676 CreateAutodiffSubgraphs(graph);
677 } else {
678 CreateAutodiffSubgraphs(graph, py::cast<int>(threshold));
679 }
680 },
681 py::arg("graph"),
682 py::arg("threshold") = py::none())
683 #if defined(BUILDING_TESTS) && !defined(USE_ROCM)
684 .def(
685 "_jit_run_cpp_tests",
686 []() {
687 // We have to release the GIL inside this method, because if we
688 // happen to initialize the autograd engine in these tests, the
689 // newly spawned worker threads will try to initialize their
690 // PyThreadState*, and they need the GIL for this.
691 pybind11::gil_scoped_release _no_gil;
692 return runJITCPPTests();
693 })
694 .def("_jit_has_cpp_tests", []() { return true; })
695 .def("_has_tensorexpr_cpp_tests", []() { return true; })
696 #else
697 .def("_jit_run_cpp_tests", []() { throw std::exception(); })
698 .def("_jit_has_cpp_tests", []() { return false; })
699 .def("_run_tensorexpr_cpp_tests", []() { throw std::exception(); })
700 .def("_has_tensorexpr_cpp_tests", []() { return false; })
701 #endif
702 .def(
703 "_jit_flatten",
704 [](py::handle& obj) {
705 auto res = python::flatten(obj);
706 return std::make_pair(res.vars, res.desc);
707 })
708 .def(
709 "_jit_unflatten",
710 [](const autograd::variable_list& vars, python::IODescriptor& desc) {
711 return py::reinterpret_steal<py::object>(
712 python::unflatten(vars, desc));
713 })
714 .def("_jit_pass_canonicalize_graph_fuser_ops", CanonicalizeOps)
715 .def("_jit_pass_decompose_ops", DecomposeOps)
716 .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
717 .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
718 .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
719 .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
720 .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
721 .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
722 .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy)
723 .def(
724 "_jit_differentiate",
725 [](Graph& g) {
726 // the python binding slightly differs in semantics
727 // it makes a copy of the input Graph, and works on that
728 // jit::differentiate mutates the input Graph
729 auto g_clone = g.copy();
730 return differentiate(g_clone);
731 })
732 .def(
733 "_jit_check_alias_annotation",
734 [](const std::shared_ptr<Graph>& g,
735 const py::tuple& args,
736 const std::string& unqualified_op_name) {
737 auto stack = toTraceableStack(args);
738 checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
739 })
740 #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
741 .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
742 .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
743 #else
744 .def("_jit_set_llga_enabled", [](bool flag) { return false; })
745 .def("_jit_llga_enabled", []() { return false; })
746 #endif
747 .def(
748 "_jit_set_tracer_state_warn",
749 [](bool new_warn) {
750 jit::tracer::getTracerStateWarnMode() = new_warn;
751 })
752 .def(
753 "_jit_get_tracer_state_warn",
754 []() {
755 bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
756 return current_tracer_warn;
757 })
758 .def(
759 "_jit_set_nvfuser_skip_node_kind",
760 [](const std::string& op_name, bool flip = true) {
761 TORCH_WARN(
762 "nvfuser is no longer supported in torch script, use _jit_set_nvfuser_skip_node_kind is deprecated and a no-op");
763 })
764 .def(
765 "_jit_set_nvfuser_enabled",
766 [](bool) {
767 TORCH_WARN(
768 "nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op");
769 })
770 .def(
771 "_jit_nvfuser_can_be_enabled",
772 []() {
773 TORCH_WARN(
774 "nvfuser is no longer supported in torch script, use _jit_nvfuser_can_be_enabled is deprecated and a no-op");
775 })
776 .def(
777 "_jit_set_nvfuser_single_node_mode",
778 [](bool) {
779 TORCH_WARN(
780 "nvfuser is no longer supported in torch script, use _jit_set_nvfuser_single_node_mode is deprecated and a no-op");
781 })
782 .def(
783 "_jit_nvfuser_single_node_mode",
784 []() {
785 TORCH_WARN(
786 "nvfuser is no longer supported in torch script, use _jit_nvfuser_single_node_mode is deprecated and a no-op");
787 })
788 .def(
789 "_jit_set_nvfuser_horizontal_mode",
790 [](bool) {
791 TORCH_WARN(
792 "nvfuser is no longer supported in torch script, use _jit_set_nvfuser_horizontal_mode is deprecated and a no-op");
793 })
794 .def(
795 "_jit_nvfuser_horizontal_mode",
796 []() {
797 TORCH_WARN(
798 "nvfuser is no longer supported in torch script, use _jit_nvfuser_horizontal_mode is deprecated and a no-op");
799 })
800 .def(
801 "_jit_set_nvfuser_guard_mode",
802 [](bool) {
803 TORCH_WARN(
804 "nvfuser is no longer supported in torch script, use _jit_set_nvfuser_guard_mode is deprecated and a no-op");
805 })
806 .def(
807 "_jit_nvfuser_enabled",
808 []() {
809 TORCH_WARN(
810 "nvfuser is no longer supported in torch script, use _jit_nvfuser_enabled is deprecated and a no-op");
811 })
812 .def(
813 "_jit_nvfuser_set_comparison_callback",
814 [](bool, py::function) {
815 TORCH_WARN(
816 "nvfuser is no longer supported in torch script, use _jit_nvfuser_set_comparison_callback is deprecated and a no-op");
817 })
818 .def(
819 "_jit_nvfuser_clear_comparison_callback",
820 []() {
821 TORCH_WARN(
822 "nvfuser is no longer supported in torch script, use _jit_nvfuser_clear_comparison_callback is deprecated and a no-op");
823 })
824 .def(
825 "_jit_set_profiling_mode",
826 [](bool profiling_flag) {
827 bool oldState = getProfilingMode();
828 getProfilingMode() = profiling_flag;
829 return oldState;
830 })
831 .def(
832 "_jit_set_profiling_executor",
833 [](bool profiling_flag) {
834 bool oldState = getExecutorMode();
835 getExecutorMode() = profiling_flag;
836 return oldState;
837 })
838 .def(
839 "_jit_set_num_profiled_runs",
840 [](size_t num) {
841 size_t old_num = getNumProfiledRuns();
842 getNumProfiledRuns() = num;
843 return old_num;
844 })
845 .def(
846 "_jit_get_num_profiled_runs",
847 [] {
848 // pybind can't automatically bind to atomic size_t
849 size_t num_runs = getNumProfiledRuns();
850 return num_runs;
851 })
852 .def(
853 "_jit_set_bailout_depth",
854 [](size_t depth) {
855 TORCH_WARN(
856 "Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, ",
857 depth,
858 ")");
859 size_t old_depth = getBailoutDepth();
860 FusionStrategy strat = {{FusionBehavior::STATIC, depth}};
861 setFusionStrategy(strat);
862 return old_depth;
863 })
864 .def(
865 "_jit_set_fusion_strategy",
866 [](const std::vector<std::pair<std::string, size_t>>& strategy) {
867 FusionStrategy vec_conv;
868 for (const auto& pair : strategy) {
869 if (pair.first == "STATIC") {
870 vec_conv.emplace_back(FusionBehavior::STATIC, pair.second);
871 } else if (pair.first == "DYNAMIC") {
872 vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second);
873 } else {
874 TORCH_INTERNAL_ASSERT(
875 false,
876 "FusionBehavior only supported 'STATIC' or 'DYNAMIC', got: ",
877 pair.first);
878 }
879 }
880 auto old_strategy = getFusionStrategy();
881 auto strat =
882 fmap(old_strategy, [](std::pair<FusionBehavior, size_t> behav) {
883 return std::pair<std::string, size_t>(
884 behav.first == FusionBehavior::STATIC ? "STATIC"
885 : "DYNAMIC",
886 behav.second);
887 });
888 setFusionStrategy(vec_conv);
889 return strat;
890 })
891 .def(
892 "_jit_set_inline_everything_mode",
893 [](bool enabled) { getInlineEverythingMode() = enabled; })
894 .def(
895 "_jit_get_inline_everything_mode",
896 []() { return getInlineEverythingMode(); })
897 .def(
898 "_jit_get_logging_option",
899 []() { return ::torch::jit::get_jit_logging_levels(); })
900 .def(
901 "_jit_set_logging_option",
902 [](std::string loggingOption) -> void {
903 ::torch::jit::set_jit_logging_levels(std::move(loggingOption));
904 })
905 .def(
906 "_jit_set_logging_stream",
907 [](const std::string& stream_name) -> void {
908 if (stream_name == "stdout") {
909 ::torch::jit::set_jit_logging_output_stream(std::cout);
910 } else if (stream_name == "stderr") {
911 ::torch::jit::set_jit_logging_output_stream(std::cerr);
912 } else {
913 std::cerr << "ERROR: only `stdout` and `stderr`"
914 << "are supported as output options" << '\n';
915 }
916 })
917 .def(
918 "_storage_id",
919 [](const at::Tensor& ten) -> int64_t {
920 return reinterpret_cast<int64_t>(
921 ten.storage().unsafeGetStorageImpl());
922 })
923 .def(
924 "_jit_try_infer_type",
925 [](py::object obj) -> InferredType {
926 return tryToInferType(std::move(obj));
927 })
928 .def(
929 "_jit_get_te_cuda_pointwise_loop_levels",
930 []() -> int {
931 using namespace torch::jit::tensorexpr;
932 return getTECudaPointwiseLoopLevels();
933 })
934 .def(
935 "_jit_set_te_cuda_pointwise_loop_levels",
936 [](int level) {
937 using namespace torch::jit::tensorexpr;
938 return getTECudaPointwiseLoopLevels() = level;
939 })
940 .def(
941 "_jit_get_te_cuda_pointwise_block_count",
942 []() -> int {
943 using namespace torch::jit::tensorexpr;
944 return getTECudaPointwiseBlockCount();
945 })
946 .def(
947 "_jit_set_te_cuda_pointwise_block_count",
948 [](int block_count) {
949 using namespace torch::jit::tensorexpr;
950 return getTECudaPointwiseBlockCount() = block_count;
951 })
952 .def(
953 "_jit_get_te_cuda_pointwise_block_size",
954 []() -> int {
955 using namespace torch::jit::tensorexpr;
956 return getTECudaPointwiseBlockSize();
957 })
958 .def(
959 "_jit_set_te_cuda_pointwise_block_size",
960 [](int block_size) {
961 using namespace torch::jit::tensorexpr;
962 return getTECudaPointwiseBlockSize() = block_size;
963 })
964 .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
965 .def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled)
966 .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
967 .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
968 .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
969 .def(
970 "_jit_set_texpr_dynamic_shape_enabled",
971 &setTensorExprDynamicShapeFusionEnabled)
972 .def(
973 "_jit_texpr_dynamic_shape_enabled",
974 &tensorExprDynamicShapeFusionEnabled)
975 .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
976 .def(
977 "_jit_set_te_generate_block_code",
978 [](bool gen_block_code) {
979 using namespace torch::jit::tensorexpr;
980 return getTEGenerateBlockCode() = gen_block_code;
981 })
982 .def(
983 "_jit_get_te_generate_block_code",
984 []() -> bool {
985 using namespace torch::jit::tensorexpr;
986 return getTEGenerateBlockCode();
987 })
988 .def(
989 "_jit_get_te_must_use_llvm_cpu",
990 []() -> bool {
991 using namespace torch::jit::tensorexpr;
992 return getTEMustUseLLVMOnCPU();
993 })
994 .def(
995 "_jit_set_te_must_use_llvm_cpu",
996 [](bool use_llvm) {
997 using namespace torch::jit::tensorexpr;
998 getTEMustUseLLVMOnCPU() = use_llvm;
999 })
1000 .def(
1001 "_jit_cat_wo_conditionals",
1002 [](bool optimize_cat) {
1003 using namespace torch::jit::tensorexpr;
1004 getCatWoConditionals() = optimize_cat;
1005 })
1006 .def(
1007 "_jit_opt_conditionals",
1008 [](bool opt_conds) {
1009 using namespace torch::jit::tensorexpr;
1010 getOptConditionals() = opt_conds;
1011 })
1012 .def(
1013 "_llvm_enabled",
1014 []() {
1015 #ifdef TORCH_ENABLE_LLVM
1016 return true;
1017 #else
1018 return false;
1019 #endif
1020 })
1021 .def(
1022 "_jit_pass_fuse_tensorexprs",
1023 [](std::shared_ptr<Graph>& g) {
1024 FuseTensorExprs(g);
1025 RemoveTensorTypeSpecializations(g);
1026 })
1027 .def(
1028 "_jit_fuser_get_fused_kernel_code",
1029 [](Graph& g, const std::vector<at::Tensor>& inps) {
1030 return debugGetFusedKernelCode(g, inps);
1031 })
1032 .def(
1033 "_jit_pass_remove_dropout",
1034 [](script::Module& module) { return removeDropout(module); })
1035 .def(
1036 "_jit_pass_refine_tuple_types",
1037 [](std::shared_ptr<Graph>& graph) { return RefineTupleTypes(graph); })
1038 .def(
1039 "_jit_pass_transform_conv1d_to_conv2d",
1040 [](std::shared_ptr<Graph>& graph) {
1041 return transformConv1dToConv2d(graph);
1042 })
1043 .def(
1044 "_jit_pass_transform_conv1d_to_conv2d",
1045 [](script::Module& module) {
1046 return transformConv1dToConv2d(module);
1047 })
1048 .def(
1049 "_jit_pass_insert_prepacked_ops",
1050 [](std::shared_ptr<Graph>& graph) {
1051 return insertPrePackedOps(graph);
1052 })
1053 .def(
1054 "_jit_pass_insert_prepacked_ops",
1055 [](script::Module& module) { return insertPrePackedOps(module); })
1056 .def(
1057 "_jit_pass_fuse_clamp_w_prepacked_linear_conv",
1058 [](script::Module& module) {
1059 return fusePrePackedLinearConvWithClamp(module);
1060 })
1061 .def(
1062 "_jit_pass_fold_prepacking_ops",
1063 [](script::Module& module) { return FoldPrePackingOps(module); })
1064 .def(
1065 "_jit_pass_optimize_for_mobile",
1066 [](script::Module& module,
1067 std::set<MobileOptimizerType>& optimization_blocklist,
1068 std::vector<std::string>& preserved_methods) {
1069 return optimizeForMobile(
1070 module, optimization_blocklist, preserved_methods);
1071 })
1072 .def(
1073 "_hack_do_not_use_clone_module_with_class",
1074 [](script::Module& module,
1075 std::vector<std::string>& ignored_methods,
1076 std::vector<std::string>& ignored_attributes) {
1077 const bool inplace = false;
1078 const std::unordered_set<std::string> ignored_methods_set(
1079 ignored_methods.begin(), ignored_methods.end());
1080 const std::unordered_set<std::string> ignored_attributes_set(
1081 ignored_attributes.begin(), ignored_attributes.end());
1082 return module.clone(
1083 inplace, ignored_methods_set, ignored_attributes_set);
1084 })
1085 .def(
1086 "_jit_pass_vulkan_insert_prepacked_ops",
1087 [](std::shared_ptr<Graph>& graph) {
1088 return vulkanInsertPrePackedOps(graph);
1089 })
1090 .def(
1091 "_jit_pass_vulkan_insert_prepacked_ops",
1092 [](script::Module& module) {
1093 return vulkanInsertPrePackedOps(module);
1094 })
1095 .def(
1096 "_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
1097 [](script::Module& module) {
1098 return vulkanFusePrePackedConvWithClamp(module);
1099 })
1100 .def(
1101 "_jit_pass_vulkan_fold_prepacking_ops",
1102 [](script::Module& module) {
1103 return vulkanFoldPrePackingOps(module);
1104 })
1105 .def(
1106 "_jit_pass_vulkan_optimize_for_mobile",
1107 [](script::Module& module,
1108 std::set<MobileOptimizerType>& optimization_blocklist,
1109 std::vector<std::string>& preserved_methods) {
1110 return vulkanOptimizeForMobile(
1111 module, optimization_blocklist, preserved_methods);
1112 })
1113 .def(
1114 "_jit_pass_metal_insert_prepacked_ops",
1115 [](std::shared_ptr<Graph>& graph) {
1116 return metalInsertPrePackedOps(graph);
1117 })
1118 .def(
1119 "_jit_pass_metal_insert_prepacked_ops",
1120 [](script::Module& module) {
1121 return metalInsertPrePackedOps(module);
1122 })
1123 .def(
1124 "_jit_pass_metal_fuse_clamp_w_prepacked_conv",
1125 [](script::Module& module) {
1126 return metalFusePrePackedConvWithClamp(module);
1127 })
1128 .def(
1129 "_jit_pass_metal_fold_prepacking_ops",
1130 [](script::Module& module) { return metalFoldPrePackingOps(module); })
1131 .def(
1132 "_jit_pass_metal_optimize_for_mobile",
1133 [](script::Module& module,
1134 std::vector<std::string>& preserved_methods) {
1135 return metalOptimizeForMobile(module, preserved_methods);
1136 })
1137 .def(
1138 "_jit_pass_filter_non_tensor_arguments",
1139 [](std::map<std::string, IValue> params) {
1140 std::map<std::string, at::Tensor> retval;
1141 for (auto& kv : params) {
1142 if (kv.second.isTensor()) {
1143 retval[kv.first] = std::move(kv.second).toTensor();
1144 }
1145 }
1146 return retval;
1147 })
1148 .def("_jit_pass_batch_mm", BatchMM)
1149 .def(
1150 "_jit_decay_packed_param_input_types",
1151 [](Graph& g) {
1152 for (Value* i : g.inputs()) {
1153 if (i->type() ==
1154 getCustomClass(
1155 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
1156 i->type() ==
1157 getCustomClass(
1158 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
1159 i->type() ==
1160 getCustomClass(
1161 "__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
1162 // Dummy CompleteTensorType to appease ONNX validator.
1163 i->setType(TensorType::create(
1164 at::kQInt8,
1165 c10::kCPU,
1166 std::vector<int64_t>{1},
1167 std::vector<int64_t>{1},
1168 std::nullopt));
1169 }
1170 }
1171 })
1172 .def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore);
1173
1174 // NB: This isn't actually used for regular PyTorch symbolic tracing;
1175 // XLA is what needs this
1176 #define SYMNODE_UNARY(n) .def(#n, [](const c10::SymNode& a) { return a->n(); })
1177 #define SYMNODE_BINARY(n) \
1178 .def(#n, [](const c10::SymNode& a, const c10::SymNode& b) { return a->n(b); })
1179 #define SYMNODE_SIZES_STRIDES(n) \
1180 .def( \
1181 #n, \
1182 [](const c10::SymNode& a, \
1183 c10::ArrayRef<c10::SymNode> sizes, \
1184 c10::ArrayRef<c10::SymNode> strides) { \
1185 return a->n(sizes, strides); \
1186 })
1187 auto symnode_class =
1188 py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
1189 // clang-format off
1190 // These DO NOT install magic methods; the SymInt/SymFloat wrapper in
1191 // Python is responsible for this
1192 SYMNODE_UNARY(clone)
1193 SYMNODE_UNARY(is_int)
1194 SYMNODE_UNARY(is_float)
1195 SYMNODE_UNARY(is_bool)
1196 SYMNODE_UNARY(bool_)
1197 SYMNODE_UNARY(int_)
1198 SYMNODE_UNARY(sym_float)
1199 SYMNODE_BINARY(add)
1200 SYMNODE_BINARY(sub)
1201 SYMNODE_BINARY(mul)
1202 SYMNODE_BINARY(truediv)
1203 SYMNODE_BINARY(int_truediv)
1204 SYMNODE_BINARY(float_truediv)
1205 SYMNODE_BINARY(pow)
1206 SYMNODE_BINARY(float_pow)
1207 SYMNODE_BINARY(pow_by_natural)
1208 SYMNODE_BINARY(floordiv)
1209 SYMNODE_BINARY(int_floordiv)
1210 SYMNODE_BINARY(mod)
1211 SYMNODE_BINARY(eq)
1212 SYMNODE_BINARY(ne)
1213 SYMNODE_BINARY(gt)
1214 SYMNODE_BINARY(lt)
1215 SYMNODE_BINARY(le)
1216 SYMNODE_BINARY(ge)
1217 SYMNODE_BINARY(sym_min)
1218 SYMNODE_BINARY(sym_max)
1219 SYMNODE_BINARY(sym_and)
1220 SYMNODE_BINARY(sym_or)
1221 SYMNODE_UNARY(sym_not)
1222 SYMNODE_UNARY(ceil)
1223 SYMNODE_UNARY(floor)
1224 SYMNODE_UNARY(neg)
1225 SYMNODE_SIZES_STRIDES(is_contiguous)
1226 SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)
1227 SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d)
1228 SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
1229 SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
1230 SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
1231 .def(
1232 "guard_int",
1233 [](const c10::SymNode& a, const char* file, int64_t line) {
1234 return a->guard_int(file, line);
1235 })
1236 .def(
1237 "guard_bool",
1238 [](const c10::SymNode& a, const char* file, int64_t line) {
1239 return a->guard_bool(file, line);
1240 })
1241 .def(
1242 "guard_float",
1243 [](const c10::SymNode& a, const char* file, int64_t line) {
1244 return a->guard_float(file, line);
1245 })
1246 .def(
1247 "expect_true",
1248 [](const c10::SymNode& a, const char* file, int64_t line) {
1249 return a->expect_true(file, line);
1250 })
1251 .def(
1252 "expect_size",
1253 [](const c10::SymNode& a, const char* file, int64_t line) {
1254 return a->expect_size(file, line);
1255 })
1256 .def(
1257 "guard_size_oblivious",
1258 [](const c10::SymNode& a, const char* file, int64_t line) {
1259 return a->guard_size_oblivious(file, line);
1260 })
1261 .def(
1262 "has_hint",
1263 [](const c10::SymNode& a) {
1264 return a->has_hint();
1265 })
1266 .def(
1267 "wrap_int",
1268 [](const c10::SymNode& a, int64_t b) {
1269 return a->wrap_int(b);
1270 })
1271 .def(
1272 "wrap_float",
1273 [](const c10::SymNode& a, double b) {
1274 return a->wrap_float(b);
1275 })
1276 .def(
1277 "wrap_bool",
1278 [](const c10::SymNode& a, bool b) {
1279 return a->wrap_bool(b);
1280 })
1281 .def(
1282 "__str__",
1283 [](const c10::SymNode& a) { return a->str(); })
1284 .def(
1285 "__repr__",
1286 [](const c10::SymNode& a) { return a->str(); })
1287 .def(
1288 "_graph_repr",
1289 [](const c10::SymNode& a) { return a->_graph_repr(); })
1290 .def(
1291 "is_constant",
1292 [](const c10::SymNode& node){
1293 return node->is_constant();
1294 })
1295 .def(
1296 "is_nested_int",
1297 [](const c10::SymNode& node) {
1298 return node->is_nested_int();
1299 })
1300 .def(
1301 "is_symbolic",
1302 [](const c10::SymNode& node) {
1303 return node->is_symbolic();
1304 })
1305 .def(
1306 "nested_int",
1307 [](const c10::SymNode& node) {
1308 return node->nested_int();
1309 })
1310 .def(
1311 "nested_int_coeff",
1312 [](const c10::SymNode& node) {
1313 return node->nested_int_coeff();
1314 })
1315 .def(
1316 "__deepcopy__",
1317 [](const c10::SymNode& node, py::handle memo) {
1318 return node->clone();
1319 });
1320
1321 // clang-format on
1322
1323 // NOLINTNEXTLINE(bugprone-unused-raii)
1324 py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
1325 .def("__repr__", [](CompleteArgumentSpec& self) {
1326 std::ostringstream s;
1327 s << self;
1328 return s.str();
1329 });
1330 // NOLINTNEXTLINE(bugprone-unused-raii)
1331 py::class_<ArgumentSpec>(m, "ArgumentSpec");
1332 py::class_<Code>(m, "Code")
1333 .def(
1334 "grad_executor_states",
1335 [](Code& c) {
1336 std::vector<GraphExecutorState> states;
1337 for (auto& e : c.grad_executors()) {
1338 states.emplace_back(e->getDebugState());
1339 }
1340 return states;
1341 })
1342 .def(
1343 "differentiable_op_executor_states",
1344 [](Code& c) {
1345 std::vector<GraphExecutorState> states;
1346 for (auto& e : c.diff_graph_op_executors()) {
1347 if (e->isOptimized()) {
1348 states.emplace_back(e->getDebugState());
1349 } else {
1350 // we leave an empty entry for node that doesn't have an
1351 // optimized plan
1352 states.emplace_back();
1353 }
1354 }
1355 return states;
1356 })
1357 .def("num_bailouts", [](Code& c) { return c.num_bailouts(); })
1358 .def("request_bailout", [](Code& c, size_t index) {
1359 c.request_bailout(index);
1360 });
1361
1362 py::class_<ExecutionPlan>(m, "ExecutionPlan")
1363 .def_property_readonly("graph", [](ExecutionPlan& s) { return s.graph; })
1364 .def_property_readonly("code", [](ExecutionPlan& s) { return s.code; });
1365
1366 py::class_<Gradient>(m, "Gradient")
1367 .def_property_readonly("f", [](Gradient& m) { return m.f; })
1368 .def_property_readonly("df", [](Gradient& m) { return m.df; })
1369 .def_property_readonly(
1370 "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
1371 .def_property_readonly(
1372 "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
1373 .def_property_readonly(
1374 "df_input_captured_inputs",
1375 [](Gradient& m) { return m.df_input_captured_inputs; })
1376 .def_property_readonly(
1377 "df_input_captured_outputs",
1378 [](Gradient& m) { return m.df_input_captured_outputs; })
1379 .def_property_readonly(
1380 "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });
1381
1382 py::class_<GraphExecutorState>(m, "GraphExecutorState")
1383 .def_property_readonly(
1384 "graph", [](GraphExecutorState& s) { return s.graph; })
1385 .def_property_readonly(
1386 "execution_plans",
1387 [](GraphExecutorState& s) { return s.execution_plans; })
1388 .def_property_readonly(
1389 "fallback", [](GraphExecutorState& s) { return s.fallback; });
1390
1391 py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
1392 .def(py::init<std::string>())
1393 .def(py::init([](const py::object& buffer) {
1394 auto writer_func = [=](const void* data, size_t size) {
1395 // Writing an empty file is a noop
1396 if (size == 0) {
1397 return size;
1398 }
1399 py::gil_scoped_acquire acquire;
1400 if (!data) {
1401 // See [Note: write_record_metadata]
1402 buffer.attr("seek")(
1403 size, py::module::import("os").attr("SEEK_CUR"));
1404 } else {
1405 auto memory_view = py::memoryview::from_memory(
1406 reinterpret_cast<const char*>(data), size);
1407 buffer.attr("write")(std::move(memory_view));
1408 }
1409 return size;
1410 };
1411 return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
1412 }))
1413 .def(py::init<const std::function<size_t(const void*, size_t)>&>())
1414 // [Note: write_record_metadata]
1415 // The write_record_metadata function is intended to write metadata (i.e.
1416 // the zipfile header and end of central directory record) for a file
1417 // while reserving nbytes of space for the file for the bytes of the
1418 // actual file to be added in later. This functionality is achieved by
1419 // defining `m_pWrite` to seek instead of write if the buffer passed is a
1420 // nullptr. This has implications on CRC-32 which will not be written at
1421 // write_record_metadata time, and will not be combined with the hash in
1422 // combined_uncomp_crc32_. We define this in `m_pWrite` rather than
1423 // extending the interface of miniz to have an `m_pSeek` since different
1424 // versions of miniz are used in fbcode/oss.
1425 .def(
1426 "write_record_metadata",
1427 [](PyTorchStreamWriter& self, const std::string& name, size_t size) {
1428 return self.writeRecord(name, nullptr, size);
1429 })
1430 .def(
1431 "write_record",
1432 [](PyTorchStreamWriter& self,
1433 const std::string& name,
1434 const char* data,
1435 size_t size) {
1436 // Since we don't know where the data come from, we cannot
1437 // release the GIL in this overload
1438 return self.writeRecord(name, data, size);
1439 })
1440 .def(
1441 "write_record",
1442 [](PyTorchStreamWriter& self,
1443 const std::string& name,
1444 py::bytes data,
1445 size_t size) {
1446 // It is not clear from the doc but according to CPython own code,
1447 // it is ok to use the result of PyBytes_AsString without the GIL
1448 // being held
1449 // https://github.com/python/cpython/blob/e2a3e4b7488aff6fdc704a0f258bc315e96c1d6e/Objects/stringlib/join.h#L67
1450 const char* data_str = PyBytes_AsString(data.ptr());
1451 py::gil_scoped_release release;
1452 return self.writeRecord(name, data_str, size);
1453 })
1454 .def(
1455 "write_record",
1456 [](PyTorchStreamWriter& self,
1457 const std::string& name,
1458 const c10::Storage& data,
1459 size_t size) {
1460 // Reading Tensor data is always ok without the GIL held
1461 py::gil_scoped_release release;
1462 return self.writeRecord(
1463 name, reinterpret_cast<const char*>(data.data()), size);
1464 })
1465 .def(
1466 "write_record",
1467 [](PyTorchStreamWriter& self,
1468 const std::string& name,
1469 uintptr_t data,
1470 size_t size) {
1471 TORCH_WARN_ONCE(
1472 "write_record(): Passing Storage by data pointer is deprecated and will be an error in ",
1473 "the future, please pass the Storage object instead.");
1474 return self.writeRecord(
1475 name, reinterpret_cast<const char*>(data), size);
1476 })
1477 .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
1478 .def("set_min_version", &PyTorchStreamWriter::setMinVersion)
1479 .def("archive_name", &PyTorchStreamWriter::archiveName)
1480 .def("serialization_id", &PyTorchStreamWriter::serializationId)
1481 .def(
1482 "get_all_written_records",
1483 &PyTorchStreamWriter::getAllWrittenRecords);
1484
1485 py::enum_<MobileOptimizerType>(m, "_MobileOptimizerType")
1486 .value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION)
1487 .value(
1488 "INSERT_FOLD_PREPACK_OPS",
1489 MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)
1490 .value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT)
1491 .value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU)
1492 .value(
1493 "HOIST_CONV_PACKED_PARAMS",
1494 MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)
1495 .value(
1496 "VULKAN_AUTOMATIC_GPU_TRANSFER",
1497 MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER);
1498
1499 // This allows PyTorchStreamReader to read from a Python buffer. It requires
1500 // that the buffer implement `seek()`, `tell()`, and `read()`.
1501 class BufferAdapter : public caffe2::serialize::ReadAdapterInterface {
1502 public:
1503 BufferAdapter(const py::object& buffer) : buffer_(buffer) {
1504 // Jump to the end of the buffer to get its size
1505 auto current = buffer.attr("tell")();
1506 start_offset_ = py::cast<size_t>(current);
1507 buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END"));
1508 size_ = py::cast<size_t>(buffer.attr("tell")()) - start_offset_;
1509 buffer.attr("seek")(current);
1510 // If we can read directly into a buffer, do that instead of an extra copy
1511 // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
1512 use_readinto_ = py::hasattr(buffer, "readinto");
1513 }
1514
1515 size_t size() const override {
1516 return size_;
1517 }
1518
1519 THPObjectPtr getMemview(void* buf, size_t n) const {
1520 THPObjectPtr memview(PyMemoryView_FromMemory(
1521 reinterpret_cast<char*>(buf), n, PyBUF_WRITE));
1522 if (!memview) {
1523 throw python_error();
1524 }
1525 return memview;
1526 }
1527
1528 size_t read(uint64_t pos, void* buf, size_t n, const char* what)
1529 const override {
1530 // Seek to desired position (NB: this has to be a Py_ssize_t or Python
1531 // throws a weird error)
1532 Py_ssize_t absolute_pos = start_offset_ + pos;
1533 buffer_.attr("seek")(absolute_pos);
1534
1535 if (use_readinto_) {
1536 auto memview = getMemview(buf, n);
1537 auto res =
1538 PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get());
1539 if (res) {
1540 int64_t i = static_cast<int64_t>(PyLong_AsLongLong(res));
1541 Py_DECREF(res);
1542 if (i > 0) {
1543 return i;
1544 }
1545 }
1546 }
1547
1548 // Read bytes into `buf` from the buffer
1549 std::string bytes = py::cast<std::string>(buffer_.attr("read")(n));
1550 std::copy(
1551 bytes.data(),
1552 bytes.data() + bytes.size(),
1553 reinterpret_cast<char*>(buf));
1554 return bytes.size();
1555 }
1556
1557 py::object buffer_;
1558 size_t size_;
1559 size_t start_offset_;
1560 bool use_readinto_{};
1561 };
1562
1563 py::class_<PyTorchStreamReader, std::shared_ptr<PyTorchStreamReader>>(
1564 m, "PyTorchFileReader")
1565 .def(py::init<std::string>())
1566 .def(py::init([](const py::object& buffer) {
1567 auto adapter = std::make_unique<BufferAdapter>(buffer);
1568 return std::make_shared<PyTorchStreamReader>(std::move(adapter));
1569 }))
1570 .def(
1571 "get_record",
1572 [](PyTorchStreamReader& self, const std::string& key) {
1573 auto [data, size] = self.getRecord(key);
1574 return py::bytes(reinterpret_cast<const char*>(data.get()), size);
1575 })
1576 .def(
1577 "has_record",
1578 [](PyTorchStreamReader& self, const std::string& key) {
1579 return self.hasRecord(key);
1580 })
1581 .def(
1582 "get_storage_from_record",
1583 [](PyTorchStreamReader& self,
1584 const std::string& key,
1585 size_t numel,
1586 py::object data_type_obj) {
1587 at::DataPtr data(std::get<0>(self.getRecord(key)));
1588 auto scalar_type =
1589 reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
1590
1591 c10::Storage storage(
1592 c10::Storage::use_byte_size_t(),
1593 numel * elementSize(scalar_type),
1594 std::move(data),
1595 /*allocator=*/nullptr,
1596 /*resizable=*/false);
1597 auto ptr =
1598 c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
1599 std::move(storage),
1600 at::DispatchKeySet(),
1601 at::CPU(scalar_type).typeMeta());
1602 return at::Tensor(std::move(ptr));
1603 })
1604 .def("serialization_id", &PyTorchStreamReader::serializationId)
1605 .def(
1606 "get_all_records",
1607 [](PyTorchStreamReader& self) { return self.getAllRecords(); })
1608 .def(
1609 "get_record_offset",
1610 [](PyTorchStreamReader& self, const std::string& key) {
1611 return self.getRecordOffset(key);
1612 });
1613
1614 // Used by torch.Package to coordinate deserialization of storages across
1615 // ScriptModules and eager modules
1616 py::class_<
1617 DeserializationStorageContext,
1618 std::shared_ptr<DeserializationStorageContext>>(
1619 m, "DeserializationStorageContext")
1620 .def(py::init<>())
1621 .def(
1622 "get_storage",
1623 [](DeserializationStorageContext& self,
1624 const std::string& name,
1625 py::object data_type_obj) {
1626 c10::Storage storage = self.getStorage(name);
1627 auto scalar_type =
1628 reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
1629 auto ptr =
1630 c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
1631 std::move(storage),
1632 at::DispatchKeySet(),
1633 at::CPU(scalar_type).typeMeta());
1634
1635 return at::Tensor(std::move(ptr));
1636 })
1637 .def(
1638 "add_storage",
1639 [](DeserializationStorageContext& self,
1640 const std::string& name,
1641 const at::Tensor& tensor) {
1642 return self.addStorage(name, tensor.storage());
1643 })
1644 .def("has_storage", &DeserializationStorageContext::hasStorage);
1645
1646 m.def(
1647 "_get_schema",
1648 [](const std::string& op_name, const std::string& overload_name) {
1649 try {
1650 auto symbol = Symbol::fromQualString(op_name);
1651 auto operations = getAllOperatorsFor(symbol);
1652 for (const auto& op : operations) {
1653 if (op->schema().overload_name() == overload_name) {
1654 return op->schema();
1655 }
1656 }
1657 throw std::runtime_error("Found no matching schema");
1658 } catch (const c10::Error& e) {
1659 auto msg = torch::get_cpp_stacktraces_enabled()
1660 ? e.what()
1661 : e.what_without_backtrace();
1662 throw std::runtime_error(msg);
1663 }
1664 });
1665
1666 m.def(
1667 "_get_operation_overload",
1668 [](const std::string& op_name,
1669 const std::string& overload_name) -> std::optional<py::tuple> {
1670 try {
1671 auto symbol = Symbol::fromQualString(op_name);
1672 auto operations = getAllOperatorsFor(symbol);
1673 bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1674 for (const auto& op : operations) {
1675 if (op->schema().overload_name() == overload_name) {
1676 auto func = py::cpp_function(
1677 [op, symbol, allow_numbers_as_tensors](
1678 const py::args& args, const py::kwargs& kwargs) {
1679 ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1680 return _get_operation_for_overload_or_packet(
1681 {op}, symbol, args, kwargs, /*is_overload*/ true);
1682 });
1683 auto func_dk =
1684 py::cpp_function([op, symbol, allow_numbers_as_tensors](
1685 c10::DispatchKey dk_,
1686 const py::args& args,
1687 const py::kwargs& kwargs) {
1688 std::optional<c10::DispatchKey> dk =
1689 std::make_optional(dk_);
1690 ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1691 return _get_operation_for_overload_or_packet(
1692 {op}, symbol, args, kwargs, /*is_overload*/ true, dk);
1693 });
1694 return std::make_optional(
1695 py::make_tuple(func, func_dk, py::cast(op->getTags().vec())));
1696 }
1697 }
1698 return std::nullopt;
1699 } catch (const c10::Error& e) {
1700 auto msg = torch::get_cpp_stacktraces_enabled()
1701 ? e.what()
1702 : e.what_without_backtrace();
1703 throw std::runtime_error(msg);
1704 }
1705 });
1706
1707 m.def(
1708 "_check_schema_allow_fake_script_object",
1709 [](const FunctionSchema& schema,
1710 const py::args& args,
1711 const py::kwargs& kwargs) {
1712 // checkSchemaAllowFakeScriptObject will throw runtime error if there is
1713 // a schema mismatch. Otherwise, it returns true.
1714 return checkSchemaAllowFakeScriptObject(schema, args, kwargs);
1715 });
1716
1717 m.def(
1718 "_jit_resolve_packet",
1719 [](const char* op_name, py::args args, const py::kwargs& kwargs) {
1720 try {
1721 auto symbol = Symbol::fromQualString(op_name);
1722 bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1723 ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1724 const auto overloads = getAllSortedOperatorsFor(symbol);
1725 auto opWithStack = getOpWithStack(overloads, std::move(args), kwargs);
1726 std::shared_ptr<Operator> overload = std::get<0>(opWithStack);
1727 auto result = overload->schema().overload_name();
1728 if (result.empty()) {
1729 result = "default";
1730 }
1731 return result;
1732 } catch (const c10::Error& e) {
1733 auto msg = torch::get_cpp_stacktraces_enabled()
1734 ? e.what()
1735 : e.what_without_backtrace();
1736 throw std::runtime_error(msg);
1737 }
1738 });
1739
1740 m.def(
1741 "_jit_get_operation",
1742 [](const std::string& op_name) {
1743 try {
1744 auto symbol = Symbol::fromQualString(op_name);
1745 const auto sortedOps = getAllSortedOperatorsFor(symbol);
1746 if (sortedOps.empty()) {
1747 // No such operator
1748 return py::make_tuple(py::none(), py::none());
1749 }
1750
1751 std::ostringstream docstring;
1752 docstring << "Automatically bound operator '" << op_name
1753 << "' with schema(s):\n";
1754
1755 for (const auto& op : sortedOps) {
1756 docstring << " " << op->schema() << "\n";
1757 }
1758
1759 py::list overload_names;
1760 for (const auto& op : sortedOps) {
1761 overload_names.append(py::str(op->schema().overload_name()));
1762 }
1763
1764 bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1765
1766 auto func = py::cpp_function(
1767 [sortedOps, symbol, allow_numbers_as_tensors](
1768 const py::args& args, const py::kwargs& kwargs) {
1769 ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1770 return _get_operation_for_overload_or_packet(
1771 sortedOps, symbol, args, kwargs, false);
1772 },
1773 py::name(symbol.toUnqualString()),
1774 py::doc(docstring.str().c_str()));
1775 return py::make_tuple(func, overload_names);
1776 } catch (const c10::Error& e) {
1777 auto msg = torch::get_cpp_stacktraces_enabled()
1778 ? e.what()
1779 : e.what_without_backtrace();
1780 throw std::runtime_error(msg);
1781 }
1782 },
1783 py::arg("qualified_name"));
1784
1785 m.def(
1786 "_maybe_call_torch_function_for_op_packet",
1787 [](py::handle op_overload_packet,
1788 const py::args& args,
1789 const py::kwargs& kwargs) {
1790 py::list ns_method =
1791 op_overload_packet.attr("_qualified_op_name").attr("split")("::");
1792 auto res = _maybe_handle_torch_function(
1793 py::cast<std::string>(ns_method[0]),
1794 py::cast<std::string>(ns_method[1]),
1795 "",
1796 false,
1797 args,
1798 kwargs);
1799 if (res) {
1800 return py::make_tuple(true, *res);
1801 } else {
1802 return py::make_tuple(false, py::none());
1803 }
1804 });
1805
1806 m.def(
1807 "parse_ir",
1808 [](const std::string& input, bool parse_tensor_constants) {
1809 auto graph = std::make_shared<Graph>();
1810 parseIR(input, &*graph, parse_tensor_constants);
1811 return graph;
1812 },
1813 py::arg("input"),
1814 py::arg("parse_tensor_constants") = false);
1815 m.def(
1816 "parse_schema",
1817 &parseSchema,
1818 py::arg("schema"),
1819 py::arg("allow_typevars") = true);
1820 m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
1821 std::ostringstream s;
1822 auto type = unifyTypeList(types, s);
1823 if (!type) {
1824 throw std::runtime_error(s.str());
1825 }
1826 return type.value();
1827 });
1828 py::enum_<SchemaArgType>(m, "_SchemaArgType")
1829 .value("input", SchemaArgType::input)
1830 .value("output", SchemaArgType::output);
1831 py::class_<SchemaArgument>(m, "_SchemaArgument")
1832 .def(py::init<SchemaArgType, size_t>())
1833 .def_readwrite("type", &SchemaArgument::type)
1834 .def_readwrite("index", &SchemaArgument::index);
1835 py::class_<SchemaInfo>(m, "_SchemaInfo")
1836 .def(py::init<FunctionSchema>())
1837 .def("is_mutable", [](SchemaInfo& self) { return self.is_mutable(); })
1838 .def(
1839 "is_mutable",
1840 [](SchemaInfo& self, const SchemaArgument& argument) {
1841 return self.is_mutable(argument);
1842 })
1843 .def(
1844 "has_argument",
1845 [](SchemaInfo& self, const std::string& name) {
1846 return self.has_argument(name);
1847 })
1848 .def(
1849 "is_mutable",
1850 [](SchemaInfo& self, const std::string& name) {
1851 return self.is_mutable(name);
1852 })
1853 .def(
1854 "may_alias",
1855 [](SchemaInfo& self,
1856 const SchemaArgument& lhs,
1857 const SchemaArgument& rhs) { return self.may_alias(lhs, rhs); })
1858 .def(
1859 "may_contain_alias",
1860 [](SchemaInfo& self,
1861 const SchemaArgument& lhs,
1862 const SchemaArgument& rhs) {
1863 return self.may_contain_alias(lhs, rhs);
1864 })
1865 .def(
1866 "add_argument_value",
1867 [](SchemaInfo& self,
1868 const std::string& name,
1869 const py::object& value) {
1870 std::optional<IValue> i_value = toTypeInferredIValueOptional(value);
1871 if (i_value) {
1872 // For normalization purposes there is an inconsistency within
1873 // torch.fx that turns all arguments named "self" into "input".
1874 // Thus this check ensures that those arguments are checked
1875 // correctly.
1876 if (name == "input" && !self.hasInputArgumentNamed("input")) {
1877 self.addArgumentValue("self", *i_value);
1878 } else {
1879 self.addArgumentValue(name, *i_value);
1880 }
1881 }
1882 })
1883 .def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
1884 std::unordered_map<std::string, IValue> value_map;
1885 for (const auto& key_pair : values) {
1886 IValue key = toTypeInferredIValue(key_pair.first);
1887 TORCH_INTERNAL_ASSERT(
1888 key.isString(),
1889 "Add argument value keys types should be strings.");
1890 std::optional<IValue> value =
1891 toTypeInferredIValueOptional(key_pair.second);
1892 if (value) {
1893 // For normalization purposes there is an inconsistency within
1894 // torch.fx that
1895 // turns all arguments named "self" into "input". Thus this check
1896 // ensures that those arguments are checked correctly.
1897 if (key.toStringRef() == "input" &&
1898 !self.hasInputArgumentNamed("input")) {
1899 self.addArgumentValue("self", *value);
1900 } else {
1901 value_map[key.toStringRef()] = *value;
1902 }
1903 }
1904 }
1905 self.addArgumentValues(value_map);
1906 });
1907 py::class_<FunctionSchema>(m, "FunctionSchema")
1908 .def_property_readonly(
1909 "name", [](FunctionSchema& self) { return self.name(); })
1910 .def_property_readonly(
1911 "overload_name",
1912 [](FunctionSchema& self) { return self.overload_name(); })
1913 .def_property_readonly(
1914 "arguments", [](FunctionSchema& self) { return self.arguments(); })
1915 .def_property_readonly(
1916 "returns", [](FunctionSchema& self) { return self.returns(); })
1917 .def(
1918 "is_backward_compatible_with",
1919 [](const FunctionSchema& self, const FunctionSchema& old_schema) {
1920 return self.isBackwardCompatibleWith(old_schema);
1921 })
1922 .def(
1923 "check_forward_compatible_with",
1924 [](const FunctionSchema& self, const FunctionSchema& old_schema) {
1925 std::ostringstream out;
1926 auto result = self.isForwardCompatibleWith(old_schema, out);
1927 return std::make_pair(result, out.str());
1928 })
1929 .def(
1930 "__eq__",
1931 [](const FunctionSchema& self, const FunctionSchema& other) {
1932 return self == other;
1933 })
1934 .def(
1935 "__hash__",
1936 [](const FunctionSchema& self) {
1937 return std::hash<FunctionSchema>{}(self);
1938 })
1939 .def(
1940 "__str__",
1941 [](FunctionSchema& self) {
1942 std::stringstream ss;
1943 ss << self;
1944 return ss.str();
1945 })
1946 .def(
1947 "__repr__",
1948 [](FunctionSchema& self) {
1949 std::stringstream ss;
1950 ss << self;
1951 return ss.str();
1952 })
1953 .def(py::pickle(
1954 [](const FunctionSchema& self) { // __getstate__
1955 std::stringstream ss;
1956 ss << self;
1957 return py::str(ss.str());
1958 },
1959 [](const py::str& schema) { // __setstate__, note: no `self` argument
1960 return parseSchema(schema);
1961 }))
1962 .def_property_readonly(
1963 "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); });
1964 py::class_<Argument>(m, "Argument")
1965 .def_property_readonly("name", [](Argument& self) { return self.name(); })
1966 .def_property_readonly("type", [](Argument& self) { return self.type(); })
1967 .def_property_readonly(
1968 "real_type", [](Argument& self) { return self.real_type(); })
1969 .def_property_readonly(
1970 "N",
1971 [](Argument& self) -> py::object {
1972 return (self.N()) ? py::cast(*self.N()) : py::none();
1973 })
1974 .def_property_readonly(
1975 "default_value",
1976 [](Argument& self) -> py::object {
1977 if (!self.default_value()) {
1978 return py::none();
1979 }
1980 IValue v = *self.default_value();
1981 return toPyObject(std::move(v));
1982 })
1983 .def(
1984 "has_default_value",
1985 [](Argument& self) -> py::bool_ {
1986 return self.default_value().has_value();
1987 })
1988 .def_property_readonly(
1989 "alias_info", [](Argument& self) { return self.alias_info(); })
1990 .def_property_readonly(
1991 "is_out", [](Argument& self) { return self.is_out(); })
1992 .def_property_readonly("kwarg_only", [](Argument& self) -> bool {
1993 return self.kwarg_only();
1994 });
1995 py::class_<AliasInfo>(m, "_AliasInfo")
1996 .def_property_readonly(
1997 "is_write", [](AliasInfo& self) { return self.isWrite(); })
1998 .def_property_readonly(
1999 "before_set",
2000 [](AliasInfo& self) {
2001 std::set<py::str> before_set_python;
2002 for (const auto& set : self.beforeSets()) {
2003 before_set_python.insert(py::str(set.toUnqualString()));
2004 }
2005 return before_set_python;
2006 })
2007 .def_property_readonly("after_set", [](AliasInfo& self) {
2008 std::set<py::str> after_set_python;
2009 for (const auto& set : self.afterSets()) {
2010 after_set_python.insert(py::str(set.toUnqualString()));
2011 }
2012 return after_set_python;
2013 });
2014 m.def("_jit_get_all_schemas", []() {
2015 const std::vector<std::shared_ptr<Operator>>& operations =
2016 getAllOperators();
2017 return fmap(operations, [](const std::shared_ptr<Operator>& op) {
2018 return op->schema();
2019 });
2020 });
2021 m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck);
2022 m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
2023 auto symbol = Symbol::fromQualString(qualified_name);
2024 const auto& operations = getAllOperatorsFor(symbol);
2025 return fmap(operations, [](const std::shared_ptr<Operator>& op) {
2026 return op->schema();
2027 });
2028 });
2029 m.def("_is_tracing", []() { return jit::tracer::isTracing(); });
2030
2031 py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
2032 m, "Future")
2033 .def(py::init([](std::vector<c10::Device> devices = {}) {
2034 return std::make_shared<PythonFutureWrapper>(
2035 c10::make_intrusive<c10::ivalue::Future>(
2036 PyObjectType::get(), std::move(devices)));
2037 }))
2038 .def(
2039 "done",
2040 // Intentionally not releasing GIL
2041 &PythonFutureWrapper::done)
2042 .def(
2043 "value",
2044 &PythonFutureWrapper::value,
2045 py::call_guard<py::gil_scoped_release>())
2046 .def(
2047 "wait",
2048 &PythonFutureWrapper::wait,
2049 py::call_guard<py::gil_scoped_release>())
2050 .def(
2051 "then",
2052 &PythonFutureWrapper::then,
2053 py::call_guard<py::gil_scoped_release>())
2054 .def(
2055 "add_done_callback",
2056 &PythonFutureWrapper::add_done_callback,
2057 py::call_guard<py::gil_scoped_release>())
2058 .def(
2059 "set_result",
2060 // Intentionally not releasing GIL
2061 &PythonFutureWrapper::markCompleted)
2062 .def(
2063 "_set_unwrap_func",
2064 // Intentionally not releasing GIL as this just does an assign
2065 [](PythonFutureWrapper& self, py::function unwrapFunc) {
2066 auto functionGuard =
2067 std::make_shared<torch::jit::PythonFunctionGuard>(
2068 std::move(unwrapFunc));
2069
2070 std::function<void(py::object)> pf =
2071 [functionGuard(std::move(functionGuard))](
2072 const py::object& inp) {
2073 return functionGuard->func_(inp);
2074 };
2075 self.unwrap_func = std::move(pf);
2076 })
2077 .def(
2078 py::pickle(
2079 /* __getstate__ */
2080 [](const PythonFutureWrapper& /* unused */) {
2081 TORCH_CHECK(false, "Can not pickle torch.futures.Future");
2082 // Note that this return has no meaning since we always
2083 // throw, it's only here to satisfy Pybind API's
2084 // requirement.
2085 return py::make_tuple();
2086 },
2087 /* __setstate__ */
2088 [](const py::tuple& /* unused */) { // NOLINT
2089 TORCH_CHECK(false, "Can not unpickle torch.futures.Future");
2090 // Note that this return has no meaning since we always
2091 // throw, it's only here to satisfy PyBind's API
2092 // requirement.
2093 return nullptr;
2094 }),
2095 py::call_guard<py::gil_scoped_release>());
2096
2097 py::class_<PythonAwaitWrapper, std::shared_ptr<PythonAwaitWrapper>>(
2098 m, "_Await")
2099 .def(
2100 "wait",
2101 &PythonAwaitWrapper::wait,
2102 py::call_guard<py::gil_scoped_release>())
2103 .def("fn", &PythonAwaitWrapper::fn)
2104 .def("args", &PythonAwaitWrapper::args)
2105 .def("type", &PythonAwaitWrapper::type)
2106 .def("is_nowait", &PythonAwaitWrapper::is_nowait)
2107 .def(
2108 "__getattr__",
2109 [](PythonAwaitWrapper& self, const std::string& name) -> py::object {
2110 // In eager mode allow Await[W] to be used as W, redirecting getattr
2111 // to the result of delayed function.
2112 return py::getattr(self.wait(), name.c_str(), py::none());
2113 })
2114 .def(
2115 py::pickle(
2116 /* __getstate__ */
2117 [](const PythonAwaitWrapper& /* unused */) {
2118 TORCH_CHECK(false, "Can not pickle torch.jit._Await");
2119 // Note that this return has no meaning since we always
2120 // throw, it's only here to satisfy Pybind API's
2121 // requirement.
2122 return py::make_tuple();
2123 },
2124 /* __setstate__ */
2125 [](const py::tuple& /* unused */) { // NOLINT
2126 TORCH_CHECK(false, "Can not unpickle torch.jit._Await");
2127 // Note that this return has no meaning since we always
2128 // throw, it's only here to satisfy PyBind's API
2129 // requirement.
2130 return nullptr;
2131 }),
2132 py::call_guard<py::gil_scoped_release>());
2133
2134 m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
2135 std::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2136 std::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2137
2138 // Only return true if we are certain that self and other are aliasing.
2139 if (!self_value || !other_value) {
2140 return false;
2141 }
2142 return self_value->isAliasOf(*other_value);
2143 });
2144 m.def("_overlaps", [](const py::object& self, const py::object& other) {
2145 std::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2146 std::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2147
2148 // Only return true if we are certain that self and other are overlapping.
2149 if (!self_value || !other_value) {
2150 return false;
2151 }
2152 return self_value->overlaps(*other_value);
2153 });
2154 m.def("_awaitable", [](const py::args& args, const py::kwargs& kwargs) {
2155 AT_ASSERT(!args.empty());
2156 py::tuple args_tup(args.size() - 1);
2157 for (const auto i : c10::irange(1, args.size())) {
2158 args_tup[i - 1] = args[i];
2159 }
2160 return std::make_shared<PythonAwaitWrapper>(
2161 py::cast<py::function>(args[0]), std::move(args_tup));
2162 });
2163 m.def("_awaitable_nowait", [](py::handle input) {
2164 return std::make_shared<PythonAwaitWrapper>(input);
2165 });
2166 m.def(
2167 "_awaitable_wait", [](const std::shared_ptr<PythonAwaitWrapper>& py_aw) {
2168 TORCH_CHECK(py_aw, "Await can't be None");
2169 return py_aw->wait();
2170 });
2171 m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
2172 AT_ASSERT(!args.empty());
2173
2174 py::function f = py::cast<py::function>(args[0]);
2175 py::tuple args_tup(args.size() - 1);
2176
2177 for (const auto i : c10::irange(1, args.size())) {
2178 args_tup[i - 1] = args[i];
2179 }
2180
2181 if (jit::tracer::isTracing()) {
2182 auto graph = jit::tracer::getTracingState()->graph;
2183 auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1));
2184 auto body_block = fork_node->addBlock();
2185
2186 Value* node_output = nullptr;
2187 py::object py_func_output;
2188 // Insert new trace ops into the fork op's sub-block
2189 WithInsertPoint guard(body_block);
2190 IValue output_ivalue;
2191 {
2192 tracer::WithNestedTracingFrame env_guard;
2193
2194 // Run the user-supplied function
2195 py_func_output = f(*args_tup, **kwargs);
2196
2197 // Convert the output of the user-supplied function to IValue. The type
2198 // information of this IValue is used both to record the correct type in
2199 // the trace.
2200 output_ivalue = toTypeInferredIValue(py_func_output);
2201 Value* out_val = jit::tracer::getValueTrace(output_ivalue);
2202 body_block->registerOutput(out_val);
2203 node_output =
2204 fork_node->output()->setType(FutureType::create(out_val->type()));
2205 }
2206
2207 auto retval =
2208 c10::make_intrusive<c10::ivalue::Future>(output_ivalue.type());
2209
2210 // Record the ivalue in the tracer
2211 jit::tracer::setValueTrace(retval, node_output);
2212
2213 // stuff the ivalue output in the Future
2214 retval->markCompleted(output_ivalue);
2215
2216 return std::make_shared<PythonFutureWrapper>(retval);
2217 } else {
2218 auto result = toTypeInferredIValue(f(*args_tup, **kwargs));
2219 auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
2220 retval->markCompleted(std::move(result));
2221 return std::make_shared<PythonFutureWrapper>(retval);
2222 }
2223 });
2224
2225 m.def("wait", [](const std::shared_ptr<PythonFutureWrapper>& fut) {
2226 TORCH_CHECK(fut, "Future can't be None");
2227 return fut->wait();
2228 });
2229
2230 m.def(
2231 "_collect_all",
2232 [](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
2233 -> std::shared_ptr<jit::PythonFutureWrapper> {
2234 auto typePtr = futures.empty() || futures[0] == nullptr
2235 ? AnyType::get()
2236 : futures[0]->fut->elementType();
2237 c10::List<c10::intrusive_ptr<c10::ivalue::Future>> asList(
2238 c10::FutureType::create(typePtr));
2239 asList.reserve(futures.size());
2240 for (const auto& f : futures) {
2241 TORCH_CHECK(f, "Future can't be None");
2242 asList.push_back(f->fut);
2243 }
2244 return std::make_shared<jit::PythonFutureWrapper>(
2245 c10::collectAll(asList),
2246 /* unwrap_func */ [futures](const py::object& /*unused*/) {
2247 // Throw errors when calling wait() on the returned Future if
2248 // any of the original futures would throw.
2249 // NB: PythonFutureWrapper takes an unwrap_func which serves as a
2250 // callback to evalute the value in the Future. RPC uses this
2251 // unwrap_func to check whether the returned py::object is a
2252 // RemoteException object, and re-throw the exception if it is.
2253 // By extracting the c10::ivalue::Future from PythonFutureWrapper
2254 // the unwrap_func on the original PythonFutureWrapper objects are
2255 // discarded, and hence it will return the RemoteException as an
2256 // object instead of re-throwing it.
2257 for (auto& fut : futures) {
2258 fut->wait();
2259 }
2260 });
2261 },
2262 py::call_guard<py::gil_scoped_release>());
2263
2264 m.def("_jit_assert_is_instance", [](py::object obj, const TypePtr& type) {
2265 toIValue(std::move(obj), type);
2266 });
2267
2268 #if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
2269 m.def("_set_print_stack_traces_on_fatal_signal", [](bool print) {
2270 c10::FatalSignalHandler::getInstance().setPrintStackTracesOnFatalSignal(
2271 print);
2272 });
2273 #endif // defined(C10_SUPPORTS_SIGNAL_HANDLER)
2274
2275 initPythonCustomClassBindings(module);
2276 initPythonIRBindings(module);
2277 tracer::initPythonTracerBindings(module);
2278 initTreeViewBindings(module);
2279 initJitScriptBindings(module);
2280 initJitBackendBindings(module);
2281 initStaticModuleBindings(module);
2282 initTensorExprBindings(module);
2283 // initNvFuserPythonBindings(module);
2284
2285 setPrintHandler([](const std::string& str) {
2286 py::gil_scoped_acquire acquire;
2287 try {
2288 auto _stdout = py::module::import("sys").attr("stdout");
2289 _stdout.attr("write")(str);
2290 } catch (py::error_already_set& e) {
2291 throw std::runtime_error(e.what());
2292 }
2293 });
2294
2295 // On exit we need to reset the print handler to default one,
2296 // because otherwise prim::Print() instruction won't work for JIT modules.
2297 auto atexit = py::module_::import("atexit");
2298 atexit.attr("register")(
2299 py::cpp_function([]() { setPrintHandler(getDefaultPrintHandler()); }));
2300 }
2301
2302 } // namespace torch::jit
2303