// torch/csrc/jit/python/init.cpp
// (AOSP mirror of upstream PyTorch, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <pybind11/pytypes.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/schema_info.h>

#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_init.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
// #include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/autocast.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/create_functional_graphs.h>
#include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/device_type_analysis.h>
#include <torch/csrc/jit/passes/dtype_analysis.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/quantization/dedup_module_uses.h>
#include <torch/csrc/jit/passes/quantization/finalize.h>
#include <torch/csrc/jit/passes/quantization/fusion_passes.h>
#include <torch/csrc/jit/passes/quantization/insert_observers.h>
#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/replacement_of_old_operators.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/python/script_init.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/static/init.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
#include <torch/csrc/utils/cpp_stacktraces.h>

#include <c10/macros/Export.h>
#include <c10/util/irange.h>
#include <c10/util/signal_handler.h>
#include <caffe2/serialize/inline_container.h>

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/iostream.h>
#include <pybind11/operators.h>

#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>

namespace torch::jit {

using c10::AliasInfo;
using c10::Argument;
using c10::FunctionSchema;
using c10::SchemaArgType;
using c10::SchemaArgument;
using c10::SymNode;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
using torch::utils::SchemaInfo;

namespace {

using autograd::variable_list;

bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  // PyObject *jit_module = PyImport_ImportModule("torch.jit");
  // THPUtils_assert(jit_module, "class loader couldn't access "
  //"torch.jit module");
  // PyObject *jit_dict = PyModule_GetDict(jit_module);

  return true;
}

static bool opAllowsNumbersAsTensors(c10::Symbol symbol) {
  return symbol.is_prims() || symbol.is_nvprims() ||
      (symbol.is_aten() &&
       torch::should_allow_numbers_as_tensors(symbol.toUnqualString()));
}

std::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
  // Errors need to be caught here because toTypeInferredIValue errors out
  // on various object types, but we want it to work with all types.
  try {
    return toTypeInferredIValue(input);
  } catch (const c10::Error& e) {
    return std::nullopt;
  }
}
} // anonymous namespace

#if !defined(USE_ROCM)
TORCH_API void runJITCPPTests();
#endif

void initJITBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto jit = m.def_submodule("_jit");

  // This is a static object, so we must leak the Python object
  // "release()" is used here to preserve 1 refcount on the
  // object, preventing it from ever being de-allocated by CPython.
  static py::handle exc =
      py::exception<JITException>(m, "JITException").release();

  py::register_exception_translator([](std::exception_ptr p) {
    try {
      if (p) {
        std::rethrow_exception(p);
      }
    } catch (const JITException& e) {
      // special handling of JITException, to set its python class name and msg
      py::gil_scoped_acquire acquire;
      const auto& className = e.getPythonClassName();
      const auto& originalMsg = e.getOriginalMsg();
      JITException::setCaughtOriginalMsg(originalMsg.value_or(""));
      JITException::setCaughtPythonClassName(className.value_or(""));
      // If we still had the py::exception<JITException> object, we could
      // just call it. But we must get a handle to leak it and there is no
      // way I can find to re-create it from the handle. So setting the
      // exception manually
      PyErr_SetString(exc.ptr(), e.what());
    }
  });

  m.def(
      "_get_caught_jit_exception_class_name",
      JITException::getCaughtPythonClassName);
  m.def(
      "_get_caught_jit_exception_original_msg",
      JITException::getCaughtOriginalMsg);
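
  // Hedged Python-side sketch of how the two getters above are meant to be
  // used after a TorchScript call raises; the except type is illustrative
  // and may vary by version:
  //
  //   try:
  //       scripted_fn(bad_input)
  //   except torch.jit.Error:
  //       cls = torch._C._get_caught_jit_exception_class_name()
  //       msg = torch._C._get_caught_jit_exception_original_msg()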

  py::class_<python::IODescriptor> iodescriptor(
      m,
      "IODescriptor"); // NOLINT(bugprone-unused-raii)

  m.def("_jit_init", loadPythonClasses)
      .def(
          "_jit_debug_fuser_num_cached_kernel_specs",
          torch::jit::fuser::debugNumCachedKernelSpecs)
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def(
          "_new_symbolic_shape_symbol",
          []() { return c10::ShapeSymbol::newSymbol().value(); })
      .def(
          "_jit_shape_compute_graph_for_node",
          [](Node* n) -> std::optional<std::shared_ptr<Graph>> {
            if (!n->maybeSchema()) {
              return std::nullopt;
            }
            return shapeComputeGraphForSchema(n->schema());
          })
      .def(
          "_jit_decomposition_graph_for_node",
          [](Node* n) -> std::optional<std::shared_ptr<Graph>> {
            if (!n->maybeSchema()) {
              return std::nullopt;
            }
            return GetDecomposition(n->schema());
          })
      .def("_jit_pass_run_decompositions", RunDecompositions)
      // using Node* here instead of Schema because looking up the schema
      // and passing it in from Python will have a different pointer than the
      // schema that is globally used for caching
      .def(
          "_jit_register_shape_compute_graph_for_node",
          [](Node* n, std::shared_ptr<Graph>& graph) {
            if (n->maybeSchema()) {
              const FunctionSchema& schema = n->schema();
              RegisterShapeComputeGraphForSchema(schema, graph);
            } else {
              TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
            }
          })
      .def(
          "_jit_register_decomposition_for_schema",
          [](const FunctionSchema& s, std::shared_ptr<Graph>& graph) {
            // Because this is invoked from Python, the FunctionSchema pointer
            // differs from the one used for caching, so find and reuse the
            // cached one.
            auto op =
                findOperatorFor(c10::OperatorName(s.name(), s.overload_name()));
            RegisterDecomposition(op->schema(), graph);
          })
      .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          [](std::shared_ptr<Graph>& graph) {
            return PropagateShapesAndBuildLargeShapeComputeGraph(
                graph, *graph->nodes().begin(), *graph->nodes().end());
          })
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          [](std::shared_ptr<Graph>& graph, Node* beg) {
            return PropagateShapesAndBuildLargeShapeComputeGraph(
                graph, beg, *graph->nodes().end());
          })
      .def(
          "_jit_pass_propagate_shapes_on_graph_and_build_compute",
          PropagateShapesAndBuildLargeShapeComputeGraph)
      .def("_jit_pass_integer_value_refinement", RefineIntegerValues)
      .def(
          "_jit_set_symbolic_shapes_test_mode",
          &setSymbolicShapeAnalysisTestMode)
      .def(
          "_jit_symbolic_shapes_test_mode_enabled",
          &symbolicShapeAnalysisTestModeEnabled)
      .def("_jit_pass_autocast", Autocast)
      .def("_jit_set_autocast_mode", &setAutocastMode)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_replace_old_ops_with_upgraders",
          [](std::shared_ptr<Graph>& g) {
            return ReplaceOldOperatorsWithUpgraders(g);
          })
      .def(
          "_jit_pass_dce",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(g->block()); // overload resolution
          })
      .def(
          "_jit_pass_dce_allow_deleting_nodes_with_side_effects",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(
                g->block(),
                true,
                DCESideEffectPolicy::
                    ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload
                                                             // resolution
          })
      .def(
          "_jit_pass_cse",
          [](std::shared_ptr<Graph>& g) {
            return EliminateCommonSubexpression(g); // overload resolution
          })
      .def(
          "_jit_pass_fuse_quantized_add_relu",
          [](std::shared_ptr<Graph>& g) {
            return FuseQuantizedAddRelu(g); // overload resolution
          })
      .def(
          "_jit_pass_insert_observers",
          [](Module& module,
             const std::string& method_name,
             const py::dict& qconfig_dict,
             bool inplace,
             int quant_type_int) {
            auto dict = py::cast<std::unordered_map<
                std::string,
                std::optional<std::tuple<Module, Module>>>>(qconfig_dict);
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertObservers(
                module, method_name, dict, inplace, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("qconfig_dict"),
          py::arg("inplace"),
          py::arg("quant_type_int") = 1)
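      // Hedged sketch of the qconfig_dict shape implied by the py::cast
      // above (keys are submodule names, "" for the top level; values are
      // pairs of scripted observer Modules, mirroring QConfig's
      // (activation, weight) convention). Names below are illustrative:
      //
      //   qconfig_dict = {"": (scripted_act_obs, scripted_weight_obs)}
      //   torch._C._jit_pass_insert_observers(
      //       model._c, "forward", qconfig_dict, False)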
      .def(
          "_jit_pass_insert_observer_method_for_ondevice_ptq",
          [](Module& module,
             const std::string& method_name,
             const py::dict& qconfig_dict,
             bool inplace,
             int quant_type_int) {
            auto dict = py::cast<std::unordered_map<
                std::string,
                std::optional<std::tuple<Module, Module>>>>(qconfig_dict);
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertObserversForOnDevicePTQ(
                module, method_name, dict, inplace, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("qconfig_dict"),
          py::arg("inplace"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_quant_dequant",
          [](Module& module,
             const std::string& method_name,
             bool inplace,
             bool debug,
             int quant_type_int) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertQuantDeQuant(
                module, method_name, inplace, debug, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_quant_dequant_for_ondevice_ptq",
          [](Module& module,
             const std::string& method_name,
             bool inplace,
             bool debug,
             int quant_type_int) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertQuantDeQuantOnDevicePTQ(
                module, method_name, inplace, debug, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_prepack_unpack",
          [](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })
      .def(
          "_jit_pass_insert_prepack_unpack",
          [](Module& module) { return InsertPrepackUnpack(module); })
      .def(
          "_jit_pass_quant_fusion",
          [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
      .def(
          "_jit_pass_fold_convbn",
          [](Module& module) { return FoldConvBatchNorm(module); })
      .def(
          "_jit_pass_dbr_quant_remove_redundant_aliases",
          [](Module& module) { return DBRQuantRemoveRedundantAliases(module); })
      .def(
          "_freeze_module",
          [](Module& module,
             std::vector<std::string>& preservedAttrs,
             bool freezeInterfaces,
             bool preserveParameters) {
            return freeze_module(
                module, preservedAttrs, freezeInterfaces, preserveParameters);
          },
          py::arg("module"),
          py::arg("preservedAttrs") = std::vector<std::string>(),
          py::arg("freezeInterfaces") = true,
          py::arg("preserveParameters") = false)
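      // Hedged usage sketch: this is the primitive torch.jit.freeze builds
      // on; the module below is illustrative.
      //
      //   scripted = torch.jit.script(MyModule().eval())
      //   frozen_c = torch._C._freeze_module(scripted._c)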
      .def("_jit_pass_concat_frozen_linear", &FrozenConcatLinear)
      .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
      .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
      .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
      .def("_jit_pass_fold_frozen_linear_bn", &FoldFrozenLinearBatchnorm)
      .def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN)
      .def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu)
      .def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose)
      .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
      .def(
          "_jit_pass_optimize_for_inference",
          [](Module& module, const std::vector<std::string>& other_methods) {
            optimize_for_inference(module, other_methods);
          },
          py::arg("module"),
          py::arg("other_methods") = std::vector<std::string>())
      .def("_jit_pass_fuse_linear", &FuseLinear)
      .def(
          "_jit_pass_fuse_add_relu",
          [](std::shared_ptr<Graph>& graph) { FuseAddRelu(graph); })
      .def("_jit_pass_dedup_module_uses", &DedupModuleUses)
      .def("_jit_pass_replicate_dequantize", &ReplicateDeQuant)
      .def(
          "_jit_pass_swap_functional_linear",
          [](std::shared_ptr<Graph>& graph) { SwapFunctionalLinear(graph); })
      .def(
          "_jit_pass_swap_functional_linear",
          [](Module& module) { SwapFunctionalLinear(module); })
      .def(
          "_jit_pass_quant_finalize",
          [](Module& module,
             int quant_type_int,
             const std::vector<std::string>& preserved_attrs) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return Finalize(module, quant_type, preserved_attrs);
          },
          py::arg("module"),
          py::arg("quant_type_int") = 1,
          py::arg("preserved_attrs") = std::vector<std::string>())
      .def(
          "_jit_pass_quant_finalize_for_ondevice_ptq",
          [](Module& module,
             int quant_type_int,
             const std::string& method_name) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return FinalizeOnDevicePTQ(module, quant_type, method_name);
          },
          py::arg("module"),
          py::arg("quant_type_int") = 1,
          py::arg("method_name"))
      .def(
          "_jit_pass_pattern_based_rewrite",
          [](const Module& m) { return PatternBasedRewrite(m); })
      .def(
          "_jit_pass_custom_pattern_based_rewrite",
          [](const std::string& pattern,
             const std::string& fused_node_name,
             const Module& m) {
            SubgraphRewriter subgraph_rewriter;
            subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name);
            subgraph_rewriter.runOnModule(m);
          })
      .def(
          "_jit_pass_custom_pattern_based_rewrite_graph",
          [](const std::string& pattern,
             const std::string& fused_node_name,
             std::shared_ptr<Graph> g,
             const std::vector<std::pair<std::string, std::string>>&
                 value_name_pairs) {
            SubgraphRewriter subgraph_rewriter;
            subgraph_rewriter.RegisterRewritePattern(
                pattern, fused_node_name, value_name_pairs);
            subgraph_rewriter.runOnGraph(g);
          },
          py::arg("pattern"),
          py::arg("fused_node_name"),
          py::arg("g"),
          py::arg("value_name_pairs") =
              std::vector<std::pair<std::string, std::string>>())
      .def("_jit_pass_constant_pooling", ConstantPooling)
      // RemoveInplaceOps is used by CoreML so it must be removed with care.
      .def("_jit_pass_propagate_dtype", DtypePropagation)
      .def("_jit_pass_propagate_device", DeviceTypePropagation)
      .def(
          "_jit_pass_remove_inplace_ops",
          [](const std::shared_ptr<Graph>& g) { return RemoveInplaceOps(g); })
      .def(
          "_jit_pass_create_functional_graphs",
          [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
      .def(
          "_jit_pass_remove_mutation",
          [](std::shared_ptr<Graph>& g) {
            RemoveListMutation(g);
            return RemoveTensorMutation(g);
          })
      .def(
          "_jit_pass_functional_to_inplace_activation",
          [](std::shared_ptr<Graph>& g) {
            return FunctionalToInplaceActivation(g);
          })
      .def(
          "_jit_pass_inplace_to_functional_activation",
          [](std::shared_ptr<Graph>& g) {
            return InplaceToFunctionalActivation(g);
          })
      .def(
          "_jit_pass_inline_functional_graphs",
          [](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })
      .def(
          "_jit_pass_peephole",
          [](const std::shared_ptr<Graph>& g, bool disable_shape_peepholes) {
            return PeepholeOptimize(g, disable_shape_peepholes);
          },
          py::arg("graph"),
          py::arg("disable_shape_peepholes") = false)
      .def(
          "_jit_pass_peephole_list_idioms",
          [](const std::shared_ptr<Graph>& g, bool refine_list_len) {
            return PeepholeOptimizeListIdioms(g, refine_list_len);
          },
          py::arg("graph"),
          py::arg("refine_list_len") = false)
      .def(
          "_jit_pass_refine_integer_values",
          [](std::shared_ptr<Graph>& g) { return RefineIntegerValues(g); })
      .def(
          "_jit_pass_fuse_addmm",
          [](std::shared_ptr<Graph>& g) { return FuseAddMM(g); })
      .def(
          "_jit_pass_canonicalize",
          [](const std::shared_ptr<Graph>& g, bool keep_unique_names = true) {
            return Canonicalize(g, keep_unique_names);
          },
          py::arg("graph"),
          py::arg("keep_unique_names") = true)
      .def("_jit_pass_lint", LintGraph)
      .def(
          "_jit_pass_complete_shape_analysis",
          [](const std::shared_ptr<Graph>& graph,
             const py::tuple& inputs,
             bool with_grad) {
            ArgumentSpecCreator arg_spec_creator(*graph);
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
            arg_spec_creator.specializeTypes(*graph, spec);
            // We only get partial specialization from the arg_spec_creator,
            // but we want full shape specialization. The alternative would be
            // to have a "complete type inference" function in
            // ArgumentSpecCreator.
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            PropagateInputShapes(graph);
          })
      .def(
          "_jit_interpret_graph",
          [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            Code code(graph, "<on-demand-func>");
            InterpreterState(code).run(stack);
            return createPyObjectForStack(std::move(stack));
          },
          py::doc(
              "Interpret a JIT graph with given inputs without running any optimization passes on it"))
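      // Hedged usage sketch for the interpreter entry point above; the graph
      // and inputs are illustrative.
      //
      //   g = torch.jit.script(fn).graph
      //   out = torch._C._jit_interpret_graph(g, (torch.randn(2, 2),))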
      .def(
          "_jit_trace_graph",
          [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            return TraceGraph(graph, stack);
          })
      .def(
          "_jit_trace_module",
          [](Module& model, const py::tuple& inputs) {
            auto graph = model.get_method("forward").graph();
            Stack stack;
            stack.reserve(inputs.size() + 1); // captures?
            push(stack, model._ivalue());
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto traced = TraceGraph(graph, stack);
            GRAPH_DUMP("Traced Graph", traced);

            // The easiest way to replace a graph in a module is to remove
            // all the nodes in the original graph and clone everything over
            // from the traced one.
            graph->block()->clear();
            graph->block()->cloneFrom(traced->block(), nullptr);
            GRAPH_DUMP("Copied Graph", graph);
          })
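      // Hedged usage sketch: _jit_trace_module rewrites the module's forward
      // graph in place with the traced version; module and inputs below are
      // illustrative.
      //
      //   scripted = torch.jit.script(MyModule())
      //   torch._C._jit_trace_module(scripted._c, (torch.randn(2, 2),))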
      .def("_jit_pass_remove_expands", RemoveExpands)
      .def("_jit_pass_erase_number_types", EraseNumberTypes)
      .def("_jit_pass_inline_fork_wait", InlineForkWait)
      .def("_jit_pass_inline", Inline)
      .def(
          "_jit_pass_lower_graph",
          [](std::shared_ptr<Graph>& graph, const Module& self) {
            return LowerGraph(*graph, self._ivalue());
          })
      .def("_jit_pass_loop_unrolling", UnrollLoops)
      .def("_jit_pass_constant_loop_unrolling", UnrollConstantLoops)
      .def(
          "_jit_pass_constant_propagation_immutable_types",
          [](std::shared_ptr<Graph>& g) {
            return ConstantPropagationImmutableTypes(g);
          })
      .def(
          "_jit_pass_constant_propagation",
          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); },
          py::arg("graph"))
      .def("_jit_pass_erase_shape_information", EraseShapeInformation)
      .def(
          "_jit_object_is_non_holding",
          [](Node& n) {
            return toIValue(n.output())->toObject()->is_weak_compilation_ref();
          })
      .def(
          "_jit_erase_non_input_shape_information",
          [](std::shared_ptr<Graph>& g) {
            std::vector<TypePtr> input_types;
            for (Value* v : g->inputs()) {
              if (auto tt = v->type()->cast<TensorType>()) {
                input_types.emplace_back(tt);
              } else {
                input_types.emplace_back(nullptr);
              }
            }
            EraseShapeInformation(g);
            for (size_t i = 0; i < input_types.size(); ++i) {
              if (input_types[i]) {
                g->inputs().at(i)->setType(input_types[i]);
              }
            }
          })
      .def(
          "_jit_pass_create_autodiff_subgraphs",
          [](const std::shared_ptr<Graph>& graph, const py::object& threshold) {
            if (threshold.is_none()) {
              CreateAutodiffSubgraphs(graph);
            } else {
              CreateAutodiffSubgraphs(graph, py::cast<int>(threshold));
            }
          },
          py::arg("graph"),
          py::arg("threshold") = py::none())
#if defined(BUILDING_TESTS) && !defined(USE_ROCM)
      .def(
          "_jit_run_cpp_tests",
          []() {
            // We have to release the GIL inside this method, because if we
            // happen to initialize the autograd engine in these tests, the
            // newly spawned worker threads will try to initialize their
            // PyThreadState*, and they need the GIL for this.
            pybind11::gil_scoped_release _no_gil;
            return runJITCPPTests();
          })
      .def("_jit_has_cpp_tests", []() { return true; })
      .def("_has_tensorexpr_cpp_tests", []() { return true; })
#else
      .def("_jit_run_cpp_tests", []() { throw std::exception(); })
      .def("_jit_has_cpp_tests", []() { return false; })
      .def("_run_tensorexpr_cpp_tests", []() { throw std::exception(); })
      .def("_has_tensorexpr_cpp_tests", []() { return false; })
#endif
      .def(
          "_jit_flatten",
          [](py::handle& obj) {
            auto res = python::flatten(obj);
            return std::make_pair(res.vars, res.desc);
          })
      .def(
          "_jit_unflatten",
          [](const autograd::variable_list& vars, python::IODescriptor& desc) {
            return py::reinterpret_steal<py::object>(
                python::unflatten(vars, desc));
          })
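      // Hedged usage sketch: _jit_flatten/_jit_unflatten round-trip a nested
      // input structure through (flat_tensors, descriptor); a, b, c below
      // are illustrative tensors.
      //
      //   flat, desc = torch._C._jit_flatten(((a, b), [c]))
      //   restored = torch._C._jit_unflatten(flat, desc)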
      .def("_jit_pass_canonicalize_graph_fuser_ops", CanonicalizeOps)
      .def("_jit_pass_decompose_ops", DecomposeOps)
      .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
      .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
      .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
      .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
      .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
      .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy)
      .def(
          "_jit_differentiate",
          [](Graph& g) {
            // The Python binding slightly differs in semantics: it makes a
            // copy of the input Graph and works on that, whereas
            // jit::differentiate mutates the input Graph.
            auto g_clone = g.copy();
            return differentiate(g_clone);
          })
      .def(
          "_jit_check_alias_annotation",
          [](const std::shared_ptr<Graph>& g,
             const py::tuple& args,
             const std::string& unqualified_op_name) {
            auto stack = toTraceableStack(args);
            checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
          })
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
      .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
      .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
#else
      .def("_jit_set_llga_enabled", [](bool flag) { return false; })
      .def("_jit_llga_enabled", []() { return false; })
#endif
      .def(
          "_jit_set_tracer_state_warn",
          [](bool new_warn) {
            jit::tracer::getTracerStateWarnMode() = new_warn;
          })
      .def(
          "_jit_get_tracer_state_warn",
          []() {
            bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
            return current_tracer_warn;
          })
      .def(
          "_jit_set_nvfuser_skip_node_kind",
          [](const std::string& op_name, bool flip = true) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_skip_node_kind is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_enabled",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_enabled is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_can_be_enabled",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_can_be_enabled is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_single_node_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_single_node_mode is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_single_node_mode",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_single_node_mode is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_horizontal_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_horizontal_mode is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_horizontal_mode",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_horizontal_mode is deprecated and a no-op");
          })
      .def(
          "_jit_set_nvfuser_guard_mode",
          [](bool) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_set_nvfuser_guard_mode is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_enabled",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_enabled is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_set_comparison_callback",
          [](bool, py::function) {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_set_comparison_callback is deprecated and a no-op");
          })
      .def(
          "_jit_nvfuser_clear_comparison_callback",
          []() {
            TORCH_WARN(
                "nvfuser is no longer supported in TorchScript; _jit_nvfuser_clear_comparison_callback is deprecated and a no-op");
          })
      .def(
          "_jit_set_profiling_mode",
          [](bool profiling_flag) {
            bool oldState = getProfilingMode();
            getProfilingMode() = profiling_flag;
            return oldState;
          })
      .def(
          "_jit_set_profiling_executor",
          [](bool profiling_flag) {
            bool oldState = getExecutorMode();
            getExecutorMode() = profiling_flag;
            return oldState;
          })
      .def(
          "_jit_set_num_profiled_runs",
          [](size_t num) {
            size_t old_num = getNumProfiledRuns();
            getNumProfiledRuns() = num;
            return old_num;
          })
      .def(
          "_jit_get_num_profiled_runs",
          [] {
            // pybind can't automatically bind to atomic size_t
            size_t num_runs = getNumProfiledRuns();
            return num_runs;
          })
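      // Hedged usage sketch: each setter above returns the previous value,
      // so callers can save and restore executor state around a region.
      //
      //   old = torch._C._jit_set_profiling_executor(True)
      //   try:
      //       ...  # run with the profiling executor
      //   finally:
      //       torch._C._jit_set_profiling_executor(old)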
      .def(
          "_jit_set_bailout_depth",
          [](size_t depth) {
            TORCH_WARN(
                "Bailout depth is deprecated; use _jit_set_fusion_strategy instead. Setting to (STATIC, ",
                depth,
                ")");
            size_t old_depth = getBailoutDepth();
            FusionStrategy strat = {{FusionBehavior::STATIC, depth}};
            setFusionStrategy(strat);
            return old_depth;
          })
      .def(
          "_jit_set_fusion_strategy",
          [](const std::vector<std::pair<std::string, size_t>>& strategy) {
            FusionStrategy vec_conv;
            for (const auto& pair : strategy) {
              if (pair.first == "STATIC") {
                vec_conv.emplace_back(FusionBehavior::STATIC, pair.second);
              } else if (pair.first == "DYNAMIC") {
                vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second);
              } else {
                TORCH_INTERNAL_ASSERT(
                    false,
                    "FusionBehavior only supports 'STATIC' or 'DYNAMIC', got: ",
                    pair.first);
              }
            }
            auto old_strategy = getFusionStrategy();
            auto strat =
                fmap(old_strategy, [](std::pair<FusionBehavior, size_t> behav) {
                  return std::pair<std::string, size_t>(
                      behav.first == FusionBehavior::STATIC ? "STATIC"
                                                            : "DYNAMIC",
                      behav.second);
                });
            setFusionStrategy(vec_conv);
            return strat;
          })
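      // Hedged usage sketch (torch.jit.set_fusion_strategy wraps this); the
      // return value is the previous strategy in the same list-of-pairs form:
      //
      //   prev = torch._C._jit_set_fusion_strategy(
      //       [("STATIC", 2), ("DYNAMIC", 10)])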
      .def(
          "_jit_set_inline_everything_mode",
          [](bool enabled) { getInlineEverythingMode() = enabled; })
      .def(
          "_jit_get_inline_everything_mode",
          []() { return getInlineEverythingMode(); })
      .def(
          "_jit_get_logging_option",
          []() { return ::torch::jit::get_jit_logging_levels(); })
      .def(
          "_jit_set_logging_option",
          [](std::string loggingOption) -> void {
            ::torch::jit::set_jit_logging_levels(std::move(loggingOption));
          })
      .def(
          "_jit_set_logging_stream",
          [](const std::string& stream_name) -> void {
            if (stream_name == "stdout") {
              ::torch::jit::set_jit_logging_output_stream(std::cout);
            } else if (stream_name == "stderr") {
              ::torch::jit::set_jit_logging_output_stream(std::cerr);
            } else {
              std::cerr << "ERROR: only `stdout` and `stderr` "
                        << "are supported as output options" << '\n';
            }
          })
      .def(
          "_storage_id",
          [](const at::Tensor& ten) -> int64_t {
            return reinterpret_cast<int64_t>(
                ten.storage().unsafeGetStorageImpl());
          })
      .def(
          "_jit_try_infer_type",
          [](py::object obj) -> InferredType {
            return tryToInferType(std::move(obj));
          })
      .def(
          "_jit_get_te_cuda_pointwise_loop_levels",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseLoopLevels();
          })
      .def(
          "_jit_set_te_cuda_pointwise_loop_levels",
          [](int level) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseLoopLevels() = level;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_count",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockCount();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_count",
          [](int block_count) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockCount() = block_count;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_size",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockSize();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_size",
          [](int block_size) {
            using namespace torch::jit::tensorexpr;
            return getTECudaPointwiseBlockSize() = block_size;
          })
      .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
      .def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled)
      .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
      .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
      .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
      .def(
          "_jit_set_texpr_dynamic_shape_enabled",
          &setTensorExprDynamicShapeFusionEnabled)
      .def(
          "_jit_texpr_dynamic_shape_enabled",
          &tensorExprDynamicShapeFusionEnabled)
      .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
      .def(
          "_jit_set_te_generate_block_code",
          [](bool gen_block_code) {
            using namespace torch::jit::tensorexpr;
            return getTEGenerateBlockCode() = gen_block_code;
          })
      .def(
          "_jit_get_te_generate_block_code",
          []() -> bool {
            using namespace torch::jit::tensorexpr;
            return getTEGenerateBlockCode();
          })
      .def(
          "_jit_get_te_must_use_llvm_cpu",
          []() -> bool {
            using namespace torch::jit::tensorexpr;
            return getTEMustUseLLVMOnCPU();
          })
      .def(
          "_jit_set_te_must_use_llvm_cpu",
          [](bool use_llvm) {
            using namespace torch::jit::tensorexpr;
            getTEMustUseLLVMOnCPU() = use_llvm;
          })
      .def(
          "_jit_cat_wo_conditionals",
          [](bool optimize_cat) {
            using namespace torch::jit::tensorexpr;
            getCatWoConditionals() = optimize_cat;
          })
      .def(
          "_jit_opt_conditionals",
          [](bool opt_conds) {
            using namespace torch::jit::tensorexpr;
            getOptConditionals() = opt_conds;
          })
      .def(
          "_llvm_enabled",
          []() {
#ifdef TORCH_ENABLE_LLVM
            return true;
#else
            return false;
#endif
          })
      .def(
          "_jit_pass_fuse_tensorexprs",
          [](std::shared_ptr<Graph>& g) {
            FuseTensorExprs(g);
            RemoveTensorTypeSpecializations(g);
          })
      .def(
          "_jit_fuser_get_fused_kernel_code",
          [](Graph& g, const std::vector<at::Tensor>& inps) {
            return debugGetFusedKernelCode(g, inps);
          })
      .def(
          "_jit_pass_remove_dropout",
          [](script::Module& module) { return removeDropout(module); })
      .def(
          "_jit_pass_refine_tuple_types",
          [](std::shared_ptr<Graph>& graph) { return RefineTupleTypes(graph); })
      .def(
          "_jit_pass_transform_conv1d_to_conv2d",
          [](std::shared_ptr<Graph>& graph) {
            return transformConv1dToConv2d(graph);
          })
      .def(
          "_jit_pass_transform_conv1d_to_conv2d",
          [](script::Module& module) {
            return transformConv1dToConv2d(module);
          })
      .def(
          "_jit_pass_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return insertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_insert_prepacked_ops",
          [](script::Module& module) { return insertPrePackedOps(module); })
      .def(
          "_jit_pass_fuse_clamp_w_prepacked_linear_conv",
          [](script::Module& module) {
            return fusePrePackedLinearConvWithClamp(module);
          })
      .def(
          "_jit_pass_fold_prepacking_ops",
          [](script::Module& module) { return FoldPrePackingOps(module); })
      .def(
          "_jit_pass_optimize_for_mobile",
          [](script::Module& module,
             std::set<MobileOptimizerType>& optimization_blocklist,
             std::vector<std::string>& preserved_methods) {
            return optimizeForMobile(
                module, optimization_blocklist, preserved_methods);
          })
      .def(
          "_hack_do_not_use_clone_module_with_class",
          [](script::Module& module,
             std::vector<std::string>& ignored_methods,
             std::vector<std::string>& ignored_attributes) {
            const bool inplace = false;
            const std::unordered_set<std::string> ignored_methods_set(
                ignored_methods.begin(), ignored_methods.end());
            const std::unordered_set<std::string> ignored_attributes_set(
                ignored_attributes.begin(), ignored_attributes.end());
            return module.clone(
                inplace, ignored_methods_set, ignored_attributes_set);
          })
      .def(
          "_jit_pass_vulkan_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return vulkanInsertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_vulkan_insert_prepacked_ops",
          [](script::Module& module) {
            return vulkanInsertPrePackedOps(module);
          })
      .def(
          "_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
          [](script::Module& module) {
            return vulkanFusePrePackedConvWithClamp(module);
          })
      .def(
          "_jit_pass_vulkan_fold_prepacking_ops",
          [](script::Module& module) {
            return vulkanFoldPrePackingOps(module);
          })
      .def(
          "_jit_pass_vulkan_optimize_for_mobile",
          [](script::Module& module,
             std::set<MobileOptimizerType>& optimization_blocklist,
             std::vector<std::string>& preserved_methods) {
            return vulkanOptimizeForMobile(
                module, optimization_blocklist, preserved_methods);
          })
      .def(
          "_jit_pass_metal_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {
            return metalInsertPrePackedOps(graph);
          })
      .def(
          "_jit_pass_metal_insert_prepacked_ops",
          [](script::Module& module) {
            return metalInsertPrePackedOps(module);
          })
      .def(
          "_jit_pass_metal_fuse_clamp_w_prepacked_conv",
          [](script::Module& module) {
            return metalFusePrePackedConvWithClamp(module);
          })
      .def(
          "_jit_pass_metal_fold_prepacking_ops",
          [](script::Module& module) { return metalFoldPrePackingOps(module); })
      .def(
          "_jit_pass_metal_optimize_for_mobile",
          [](script::Module& module,
             std::vector<std::string>& preserved_methods) {
            return metalOptimizeForMobile(module, preserved_methods);
          })
      .def(
          "_jit_pass_filter_non_tensor_arguments",
          [](std::map<std::string, IValue> params) {
            std::map<std::string, at::Tensor> retval;
            for (auto& kv : params) {
              if (kv.second.isTensor()) {
                retval[kv.first] = std::move(kv.second).toTensor();
              }
            }
            return retval;
          })
      .def("_jit_pass_batch_mm", BatchMM)
      .def(
          "_jit_decay_packed_param_input_types",
          [](Graph& g) {
            for (Value* i : g.inputs()) {
              if (i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
                  i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
                  i->type() ==
                      getCustomClass(
                          "__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
                // Dummy CompleteTensorType to appease ONNX validator.
                i->setType(TensorType::create(
                    at::kQInt8,
                    c10::kCPU,
                    std::vector<int64_t>{1},
                    std::vector<int64_t>{1},
                    std::nullopt));
              }
            }
          })
      .def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore);

  // NB: This isn't actually used for regular PyTorch symbolic tracing;
  // XLA is what needs this
#define SYMNODE_UNARY(n) .def(#n, [](const c10::SymNode& a) { return a->n(); })
#define SYMNODE_BINARY(n) \
  .def(#n, [](const c10::SymNode& a, const c10::SymNode& b) { return a->n(b); })
#define SYMNODE_SIZES_STRIDES(n)                \
  .def(                                         \
      #n,                                       \
      [](const c10::SymNode& a,                 \
         c10::ArrayRef<c10::SymNode> sizes,     \
         c10::ArrayRef<c10::SymNode> strides) { \
        return a->n(sizes, strides);            \
      })
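  // For orientation, SYMNODE_BINARY(add) below expands to:
  //   .def("add", [](const c10::SymNode& a, const c10::SymNode& b) {
  //     return a->add(b);
  //   })
  // i.e. each macro invocation binds one SymNodeImpl virtual method.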
1187   auto symnode_class =
1188       py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
1189       // clang-format off
1190       // These DO NOT install magic methods; the SymInt/SymFloat wrapper in
1191       // Python is responsible for this
1192       SYMNODE_UNARY(clone)
1193       SYMNODE_UNARY(is_int)
1194       SYMNODE_UNARY(is_float)
1195       SYMNODE_UNARY(is_bool)
1196       SYMNODE_UNARY(bool_)
1197       SYMNODE_UNARY(int_)
1198       SYMNODE_UNARY(sym_float)
1199       SYMNODE_BINARY(add)
1200       SYMNODE_BINARY(sub)
1201       SYMNODE_BINARY(mul)
1202       SYMNODE_BINARY(truediv)
1203       SYMNODE_BINARY(int_truediv)
1204       SYMNODE_BINARY(float_truediv)
1205       SYMNODE_BINARY(pow)
1206       SYMNODE_BINARY(float_pow)
1207       SYMNODE_BINARY(pow_by_natural)
1208       SYMNODE_BINARY(floordiv)
1209       SYMNODE_BINARY(int_floordiv)
1210       SYMNODE_BINARY(mod)
1211       SYMNODE_BINARY(eq)
1212       SYMNODE_BINARY(ne)
1213       SYMNODE_BINARY(gt)
1214       SYMNODE_BINARY(lt)
1215       SYMNODE_BINARY(le)
1216       SYMNODE_BINARY(ge)
1217       SYMNODE_BINARY(sym_min)
1218       SYMNODE_BINARY(sym_max)
1219       SYMNODE_BINARY(sym_and)
1220       SYMNODE_BINARY(sym_or)
1221       SYMNODE_UNARY(sym_not)
1222       SYMNODE_UNARY(ceil)
1223       SYMNODE_UNARY(floor)
1224       SYMNODE_UNARY(neg)
1225       SYMNODE_SIZES_STRIDES(is_contiguous)
1226       SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)
1227       SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d)
1228       SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
1229       SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
1230       SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
1231       .def(
1232           "guard_int",
1233           [](const c10::SymNode& a, const char* file, int64_t line) {
1234             return a->guard_int(file, line);
1235           })
1236       .def(
1237           "guard_bool",
1238           [](const c10::SymNode& a, const char* file, int64_t line) {
1239             return a->guard_bool(file, line);
1240           })
1241       .def(
1242           "guard_float",
1243           [](const c10::SymNode& a, const char* file, int64_t line) {
1244             return a->guard_float(file, line);
1245           })
1246       .def(
1247           "expect_true",
1248           [](const c10::SymNode& a, const char* file, int64_t line) {
1249             return a->expect_true(file, line);
1250           })
1251       .def(
1252           "expect_size",
1253           [](const c10::SymNode& a, const char* file, int64_t line) {
1254             return a->expect_size(file, line);
1255           })
1256       .def(
1257           "guard_size_oblivious",
1258           [](const c10::SymNode& a, const char* file, int64_t line) {
1259             return a->guard_size_oblivious(file, line);
1260           })
1261       .def(
1262           "has_hint",
1263           [](const c10::SymNode& a) {
1264             return a->has_hint();
1265           })
1266       .def(
1267           "wrap_int",
1268           [](const c10::SymNode& a, int64_t b) {
1269             return a->wrap_int(b);
1270           })
1271       .def(
1272           "wrap_float",
1273           [](const c10::SymNode& a, double b) {
1274             return a->wrap_float(b);
1275           })
1276       .def(
1277           "wrap_bool",
1278           [](const c10::SymNode& a, bool b) {
1279             return a->wrap_bool(b);
1280           })
1281       .def(
1282           "__str__",
1283           [](const c10::SymNode& a) { return a->str(); })
1284       .def(
1285           "__repr__",
1286           [](const c10::SymNode& a) { return a->str(); })
1287       .def(
1288           "_graph_repr",
1289           [](const c10::SymNode& a) { return a->_graph_repr(); })
1290       .def(
1291           "is_constant",
1292           [](const c10::SymNode& node) {
1293             return node->is_constant();
1294           })
1295       .def(
1296           "is_nested_int",
1297           [](const c10::SymNode& node) {
1298             return node->is_nested_int();
1299           })
1300       .def(
1301           "is_symbolic",
1302           [](const c10::SymNode& node) {
1303             return node->is_symbolic();
1304           })
1305       .def(
1306           "nested_int",
1307           [](const c10::SymNode& node) {
1308             return node->nested_int();
1309           })
1310       .def(
1311           "nested_int_coeff",
1312           [](const c10::SymNode& node) {
1313             return node->nested_int_coeff();
1314           })
1315       .def(
1316           "__deepcopy__",
1317           [](const c10::SymNode& node, py::handle memo) {
1318             return node->clone();
1319           });
1320 
1321   // clang-format on
1322 
1323   // NOLINTNEXTLINE(bugprone-unused-raii)
1324   py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
1325       .def("__repr__", [](CompleteArgumentSpec& self) {
1326         std::ostringstream s;
1327         s << self;
1328         return s.str();
1329       });
1330   // NOLINTNEXTLINE(bugprone-unused-raii)
1331   py::class_<ArgumentSpec>(m, "ArgumentSpec");
1332   py::class_<Code>(m, "Code")
1333       .def(
1334           "grad_executor_states",
1335           [](Code& c) {
1336             std::vector<GraphExecutorState> states;
1337             for (auto& e : c.grad_executors()) {
1338               states.emplace_back(e->getDebugState());
1339             }
1340             return states;
1341           })
1342       .def(
1343           "differentiable_op_executor_states",
1344           [](Code& c) {
1345             std::vector<GraphExecutorState> states;
1346             for (auto& e : c.diff_graph_op_executors()) {
1347               if (e->isOptimized()) {
1348                 states.emplace_back(e->getDebugState());
1349               } else {
1350                 // we leave an empty entry for a node that doesn't have
1351                 // an optimized plan
1352                 states.emplace_back();
1353               }
1354             }
1355             return states;
1356           })
1357       .def("num_bailouts", [](Code& c) { return c.num_bailouts(); })
1358       .def("request_bailout", [](Code& c, size_t index) {
1359         c.request_bailout(index);
1360       });
1361 
1362   py::class_<ExecutionPlan>(m, "ExecutionPlan")
1363       .def_property_readonly("graph", [](ExecutionPlan& s) { return s.graph; })
1364       .def_property_readonly("code", [](ExecutionPlan& s) { return s.code; });
1365 
1366   py::class_<Gradient>(m, "Gradient")
1367       .def_property_readonly("f", [](Gradient& m) { return m.f; })
1368       .def_property_readonly("df", [](Gradient& m) { return m.df; })
1369       .def_property_readonly(
1370           "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
1371       .def_property_readonly(
1372           "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
1373       .def_property_readonly(
1374           "df_input_captured_inputs",
1375           [](Gradient& m) { return m.df_input_captured_inputs; })
1376       .def_property_readonly(
1377           "df_input_captured_outputs",
1378           [](Gradient& m) { return m.df_input_captured_outputs; })
1379       .def_property_readonly(
1380           "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });
1381 
1382   py::class_<GraphExecutorState>(m, "GraphExecutorState")
1383       .def_property_readonly(
1384           "graph", [](GraphExecutorState& s) { return s.graph; })
1385       .def_property_readonly(
1386           "execution_plans",
1387           [](GraphExecutorState& s) { return s.execution_plans; })
1388       .def_property_readonly(
1389           "fallback", [](GraphExecutorState& s) { return s.fallback; });
1390 
1391   py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
1392       .def(py::init<std::string>())
1393       .def(py::init([](const py::object& buffer) {
1394         auto writer_func = [=](const void* data, size_t size) {
1395           // Writing zero bytes is a no-op
1396           if (size == 0) {
1397             return size;
1398           }
1399           py::gil_scoped_acquire acquire;
1400           if (!data) {
1401             // See [Note: write_record_metadata]
1402             buffer.attr("seek")(
1403                 size, py::module::import("os").attr("SEEK_CUR"));
1404           } else {
1405             auto memory_view = py::memoryview::from_memory(
1406                 reinterpret_cast<const char*>(data), size);
1407             buffer.attr("write")(std::move(memory_view));
1408           }
1409           return size;
1410         };
1411         return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
1412       }))
1413       .def(py::init<const std::function<size_t(const void*, size_t)>&>())
1414       // [Note: write_record_metadata]
1415       // The write_record_metadata function writes the metadata (i.e. the
1416       // zipfile header and end of central directory record) for a record
1417       // while reserving nbytes of space into which the actual bytes of the
1418       // record can be written later. This is achieved by defining `m_pWrite`
1419       // to seek instead of write when the buffer passed is a nullptr. A
1420       // consequence is that the CRC-32 is not computed at
1421       // write_record_metadata time and is not combined with the hash in
1422       // combined_uncomp_crc32_. We define this in `m_pWrite` rather than
1423       // extending the interface of miniz with an `m_pSeek` since different
1424       // versions of miniz are used in fbcode/oss.
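      // Illustrative Python-side usage (archive name and nbytes are
      // hypothetical):
      //   w = torch._C.PyTorchFileWriter("archive.pt")
      //   w.write_record_metadata("data/0", nbytes)
      //   # ...later, write the actual nbytes of the "data/0" payload...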
1425       .def(
1426           "write_record_metadata",
1427           [](PyTorchStreamWriter& self, const std::string& name, size_t size) {
1428             return self.writeRecord(name, nullptr, size);
1429           })
1430       .def(
1431           "write_record",
1432           [](PyTorchStreamWriter& self,
1433              const std::string& name,
1434              const char* data,
1435              size_t size) {
1436             // Since we don't know where the data comes from, we cannot
1437             // release the GIL in this overload
1438             return self.writeRecord(name, data, size);
1439           })
1440       .def(
1441           "write_record",
1442           [](PyTorchStreamWriter& self,
1443              const std::string& name,
1444              py::bytes data,
1445              size_t size) {
1446             // It is not clear from the docs, but according to CPython's own
1447             // code it is ok to use the result of PyBytes_AsString without
1448             // the GIL being held
1449             // https://github.com/python/cpython/blob/e2a3e4b7488aff6fdc704a0f258bc315e96c1d6e/Objects/stringlib/join.h#L67
1450             const char* data_str = PyBytes_AsString(data.ptr());
1451             py::gil_scoped_release release;
1452             return self.writeRecord(name, data_str, size);
1453           })
1454       .def(
1455           "write_record",
1456           [](PyTorchStreamWriter& self,
1457              const std::string& name,
1458              const c10::Storage& data,
1459              size_t size) {
1460             // Reading Tensor data is always ok without the GIL held
1461             py::gil_scoped_release release;
1462             return self.writeRecord(
1463                 name, reinterpret_cast<const char*>(data.data()), size);
1464           })
1465       .def(
1466           "write_record",
1467           [](PyTorchStreamWriter& self,
1468              const std::string& name,
1469              uintptr_t data,
1470              size_t size) {
1471             TORCH_WARN_ONCE(
1472                 "write_record(): Passing Storage by data pointer is deprecated and will be an error in ",
1473                 "the future, please pass the Storage object instead.");
1474             return self.writeRecord(
1475                 name, reinterpret_cast<const char*>(data), size);
1476           })
1477       .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
1478       .def("set_min_version", &PyTorchStreamWriter::setMinVersion)
1479       .def("archive_name", &PyTorchStreamWriter::archiveName)
1480       .def("serialization_id", &PyTorchStreamWriter::serializationId)
1481       .def(
1482           "get_all_written_records",
1483           &PyTorchStreamWriter::getAllWrittenRecords);
1484 
1485   py::enum_<MobileOptimizerType>(m, "_MobileOptimizerType")
1486       .value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION)
1487       .value(
1488           "INSERT_FOLD_PREPACK_OPS",
1489           MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)
1490       .value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT)
1491       .value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU)
1492       .value(
1493           "HOIST_CONV_PACKED_PARAMS",
1494           MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)
1495       .value(
1496           "VULKAN_AUTOMATIC_GPU_TRANSFER",
1497           MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER);
1498 
1499   // This allows PyTorchStreamReader to read from a Python buffer. It requires
1500   // that the buffer implement `seek()`, `tell()`, and `read()`.
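  // e.g. (illustrative): torch._C.PyTorchFileReader(io.BytesIO(data))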
1501   class BufferAdapter : public caffe2::serialize::ReadAdapterInterface {
1502    public:
1503     BufferAdapter(const py::object& buffer) : buffer_(buffer) {
1504       // Jump to the end of the buffer to get its size
1505       auto current = buffer.attr("tell")();
1506       start_offset_ = py::cast<size_t>(current);
1507       buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END"));
1508       size_ = py::cast<size_t>(buffer.attr("tell")()) - start_offset_;
1509       buffer.attr("seek")(current);
1510       // If we can read directly into a buffer, do that instead of an extra copy
1511       // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
1512       use_readinto_ = py::hasattr(buffer, "readinto");
1513     }
1514 
1515     size_t size() const override {
1516       return size_;
1517     }
1518 
1519     THPObjectPtr getMemview(void* buf, size_t n) const {
1520       THPObjectPtr memview(PyMemoryView_FromMemory(
1521           reinterpret_cast<char*>(buf), n, PyBUF_WRITE));
1522       if (!memview) {
1523         throw python_error();
1524       }
1525       return memview;
1526     }
1527 
1528     size_t read(uint64_t pos, void* buf, size_t n, const char* what)
1529         const override {
1530       // Seek to the desired position (NB: this has to be a Py_ssize_t or
1531       // Python throws a weird error)
1532       Py_ssize_t absolute_pos = start_offset_ + pos;
1533       buffer_.attr("seek")(absolute_pos);
1534 
1535       if (use_readinto_) {
1536         auto memview = getMemview(buf, n);
1537         auto res =
1538             PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get());
1539         if (res) {
1540           int64_t i = static_cast<int64_t>(PyLong_AsLongLong(res));
1541           Py_DECREF(res);
1542           if (i > 0) {
1543             return i;
1544           }
1545         }
1546       }
1547 
1548       // Read bytes into `buf` from the buffer
1549       std::string bytes = py::cast<std::string>(buffer_.attr("read")(n));
1550       std::copy(
1551           bytes.data(),
1552           bytes.data() + bytes.size(),
1553           reinterpret_cast<char*>(buf));
1554       return bytes.size();
1555     }
1556 
1557     py::object buffer_;
1558     size_t size_;
1559     size_t start_offset_;
1560     bool use_readinto_{};
1561   };
1562 
1563   py::class_<PyTorchStreamReader, std::shared_ptr<PyTorchStreamReader>>(
1564       m, "PyTorchFileReader")
1565       .def(py::init<std::string>())
1566       .def(py::init([](const py::object& buffer) {
1567         auto adapter = std::make_unique<BufferAdapter>(buffer);
1568         return std::make_shared<PyTorchStreamReader>(std::move(adapter));
1569       }))
1570       .def(
1571           "get_record",
1572           [](PyTorchStreamReader& self, const std::string& key) {
1573             auto [data, size] = self.getRecord(key);
1574             return py::bytes(reinterpret_cast<const char*>(data.get()), size);
1575           })
1576       .def(
1577           "has_record",
1578           [](PyTorchStreamReader& self, const std::string& key) {
1579             return self.hasRecord(key);
1580           })
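      // Returns a CPU tensor that takes ownership of the record's bytes,
      // interpreting them as `numel` elements of the given dtype.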
1581       .def(
1582           "get_storage_from_record",
1583           [](PyTorchStreamReader& self,
1584              const std::string& key,
1585              size_t numel,
1586              py::object data_type_obj) {
1587             at::DataPtr data(std::get<0>(self.getRecord(key)));
1588             auto scalar_type =
1589                 reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
1590 
1591             c10::Storage storage(
1592                 c10::Storage::use_byte_size_t(),
1593                 numel * elementSize(scalar_type),
1594                 std::move(data),
1595                 /*allocator=*/nullptr,
1596                 /*resizable=*/false);
1597             auto ptr =
1598                 c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
1599                     std::move(storage),
1600                     at::DispatchKeySet(),
1601                     at::CPU(scalar_type).typeMeta());
1602             return at::Tensor(std::move(ptr));
1603           })
1604       .def("serialization_id", &PyTorchStreamReader::serializationId)
1605       .def(
1606           "get_all_records",
1607           [](PyTorchStreamReader& self) { return self.getAllRecords(); })
1608       .def(
1609           "get_record_offset",
1610           [](PyTorchStreamReader& self, const std::string& key) {
1611             return self.getRecordOffset(key);
1612           });
1613 
1614   // Used by torch.package to coordinate deserialization of storages across
1615   // ScriptModules and eager modules
1616   py::class_<
1617       DeserializationStorageContext,
1618       std::shared_ptr<DeserializationStorageContext>>(
1619       m, "DeserializationStorageContext")
1620       .def(py::init<>())
1621       .def(
1622           "get_storage",
1623           [](DeserializationStorageContext& self,
1624              const std::string& name,
1625              py::object data_type_obj) {
1626             c10::Storage storage = self.getStorage(name);
1627             auto scalar_type =
1628                 reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
1629             auto ptr =
1630                 c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
1631                     std::move(storage),
1632                     at::DispatchKeySet(),
1633                     at::CPU(scalar_type).typeMeta());
1634 
1635             return at::Tensor(std::move(ptr));
1636           })
1637       .def(
1638           "add_storage",
1639           [](DeserializationStorageContext& self,
1640              const std::string& name,
1641              const at::Tensor& tensor) {
1642             return self.addStorage(name, tensor.storage());
1643           })
1644       .def("has_storage", &DeserializationStorageContext::hasStorage);
1645 
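  // Looks up the FunctionSchema for a given (op name, overload name) pair,
  // e.g. torch._C._get_schema("aten::add", "Tensor").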
1646   m.def(
1647       "_get_schema",
1648       [](const std::string& op_name, const std::string& overload_name) {
1649         try {
1650           auto symbol = Symbol::fromQualString(op_name);
1651           auto operations = getAllOperatorsFor(symbol);
1652           for (const auto& op : operations) {
1653             if (op->schema().overload_name() == overload_name) {
1654               return op->schema();
1655             }
1656           }
1657           throw std::runtime_error("Found no matching schema");
1658         } catch (const c10::Error& e) {
1659           auto msg = torch::get_cpp_stacktraces_enabled()
1660               ? e.what()
1661               : e.what_without_backtrace();
1662           throw std::runtime_error(msg);
1663         }
1664       });
1665 
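  // Returns (overload fn, overload fn taking an explicit DispatchKey, op
  // tags) for the requested overload, or None if it doesn't exist.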
1666   m.def(
1667       "_get_operation_overload",
1668       [](const std::string& op_name,
1669          const std::string& overload_name) -> std::optional<py::tuple> {
1670         try {
1671           auto symbol = Symbol::fromQualString(op_name);
1672           auto operations = getAllOperatorsFor(symbol);
1673           bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1674           for (const auto& op : operations) {
1675             if (op->schema().overload_name() == overload_name) {
1676               auto func = py::cpp_function(
1677                   [op, symbol, allow_numbers_as_tensors](
1678                       const py::args& args, const py::kwargs& kwargs) {
1679                     ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1680                     return _get_operation_for_overload_or_packet(
1681                         {op}, symbol, args, kwargs, /*is_overload*/ true);
1682                   });
1683               auto func_dk =
1684                   py::cpp_function([op, symbol, allow_numbers_as_tensors](
1685                                        c10::DispatchKey dk_,
1686                                        const py::args& args,
1687                                        const py::kwargs& kwargs) {
1688                     std::optional<c10::DispatchKey> dk =
1689                         std::make_optional(dk_);
1690                     ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1691                     return _get_operation_for_overload_or_packet(
1692                         {op}, symbol, args, kwargs, /*is_overload*/ true, dk);
1693                   });
1694               return std::make_optional(
1695                   py::make_tuple(func, func_dk, py::cast(op->getTags().vec())));
1696             }
1697           }
1698           return std::nullopt;
1699         } catch (const c10::Error& e) {
1700           auto msg = torch::get_cpp_stacktraces_enabled()
1701               ? e.what()
1702               : e.what_without_backtrace();
1703           throw std::runtime_error(msg);
1704         }
1705       });
1706 
1707   m.def(
1708       "_check_schema_allow_fake_script_object",
1709       [](const FunctionSchema& schema,
1710          const py::args& args,
1711          const py::kwargs& kwargs) {
1712         // checkSchemaAllowFakeScriptObject will throw runtime error if there is
1713         // a schema mismatch. Otherwise, it returns true.
1714         return checkSchemaAllowFakeScriptObject(schema, args, kwargs);
1715       });
1716 
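  // Resolves which overload of an op packet the given args/kwargs would
  // select and returns its overload name ("default" when empty), e.g.
  // (illustrative) torch._C._jit_resolve_packet("aten::add", x, y).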
1717   m.def(
1718       "_jit_resolve_packet",
1719       [](const char* op_name, py::args args, const py::kwargs& kwargs) {
1720         try {
1721           auto symbol = Symbol::fromQualString(op_name);
1722           bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1723           ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1724           const auto overloads = getAllSortedOperatorsFor(symbol);
1725           auto opWithStack = getOpWithStack(overloads, std::move(args), kwargs);
1726           std::shared_ptr<Operator> overload = std::get<0>(opWithStack);
1727           auto result = overload->schema().overload_name();
1728           if (result.empty()) {
1729             result = "default";
1730           }
1731           return result;
1732         } catch (const c10::Error& e) {
1733           auto msg = torch::get_cpp_stacktraces_enabled()
1734               ? e.what()
1735               : e.what_without_backtrace();
1736           throw std::runtime_error(msg);
1737         }
1738       });
1739 
1740   m.def(
1741       "_jit_get_operation",
1742       [](const std::string& op_name) {
1743         try {
1744           auto symbol = Symbol::fromQualString(op_name);
1745           const auto sortedOps = getAllSortedOperatorsFor(symbol);
1746           if (sortedOps.empty()) {
1747             // No such operator
1748             return py::make_tuple(py::none(), py::none());
1749           }
1750 
1751           std::ostringstream docstring;
1752           docstring << "Automatically bound operator '" << op_name
1753                     << "' with schema(s):\n";
1754 
1755           for (const auto& op : sortedOps) {
1756             docstring << "  " << op->schema() << "\n";
1757           }
1758 
1759           py::list overload_names;
1760           for (const auto& op : sortedOps) {
1761             overload_names.append(py::str(op->schema().overload_name()));
1762           }
1763 
1764           bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol);
1765 
1766           auto func = py::cpp_function(
1767               [sortedOps, symbol, allow_numbers_as_tensors](
1768                   const py::args& args, const py::kwargs& kwargs) {
1769                 ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
1770                 return _get_operation_for_overload_or_packet(
1771                     sortedOps, symbol, args, kwargs, false);
1772               },
1773               py::name(symbol.toUnqualString()),
1774               py::doc(docstring.str().c_str()));
1775           return py::make_tuple(func, overload_names);
1776         } catch (const c10::Error& e) {
1777           auto msg = torch::get_cpp_stacktraces_enabled()
1778               ? e.what()
1779               : e.what_without_backtrace();
1780           throw std::runtime_error(msg);
1781         }
1782       },
1783       py::arg("qualified_name"));
1784 
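  // Calls any __torch_function__ override for an OpOverloadPacket; returns
  // (True, result) when an override handled the call, else (False, None).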
1785   m.def(
1786       "_maybe_call_torch_function_for_op_packet",
1787       [](py::handle op_overload_packet,
1788          const py::args& args,
1789          const py::kwargs& kwargs) {
1790         py::list ns_method =
1791             op_overload_packet.attr("_qualified_op_name").attr("split")("::");
1792         auto res = _maybe_handle_torch_function(
1793             py::cast<std::string>(ns_method[0]),
1794             py::cast<std::string>(ns_method[1]),
1795             "",
1796             false,
1797             args,
1798             kwargs);
1799         if (res) {
1800           return py::make_tuple(true, *res);
1801         } else {
1802           return py::make_tuple(false, py::none());
1803         }
1804       });
1805 
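  // Parses textual TorchScript IR into a Graph, e.g. (illustrative):
  //   torch._C.parse_ir("graph(%x : Tensor):\n  return (%x)")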
1806   m.def(
1807       "parse_ir",
1808       [](const std::string& input, bool parse_tensor_constants) {
1809         auto graph = std::make_shared<Graph>();
1810         parseIR(input, &*graph, parse_tensor_constants);
1811         return graph;
1812       },
1813       py::arg("input"),
1814       py::arg("parse_tensor_constants") = false);
1815   m.def(
1816       "parse_schema",
1817       &parseSchema,
1818       py::arg("schema"),
1819       py::arg("allow_typevars") = true);
1820   m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
1821     std::ostringstream s;
1822     auto type = unifyTypeList(types, s);
1823     if (!type) {
1824       throw std::runtime_error(s.str());
1825     }
1826     return type.value();
1827   });
1828   py::enum_<SchemaArgType>(m, "_SchemaArgType")
1829       .value("input", SchemaArgType::input)
1830       .value("output", SchemaArgType::output);
1831   py::class_<SchemaArgument>(m, "_SchemaArgument")
1832       .def(py::init<SchemaArgType, size_t>())
1833       .def_readwrite("type", &SchemaArgument::type)
1834       .def_readwrite("index", &SchemaArgument::index);
1835   py::class_<SchemaInfo>(m, "_SchemaInfo")
1836       .def(py::init<FunctionSchema>())
1837       .def("is_mutable", [](SchemaInfo& self) { return self.is_mutable(); })
1838       .def(
1839           "is_mutable",
1840           [](SchemaInfo& self, const SchemaArgument& argument) {
1841             return self.is_mutable(argument);
1842           })
1843       .def(
1844           "has_argument",
1845           [](SchemaInfo& self, const std::string& name) {
1846             return self.has_argument(name);
1847           })
1848       .def(
1849           "is_mutable",
1850           [](SchemaInfo& self, const std::string& name) {
1851             return self.is_mutable(name);
1852           })
1853       .def(
1854           "may_alias",
1855           [](SchemaInfo& self,
1856              const SchemaArgument& lhs,
1857              const SchemaArgument& rhs) { return self.may_alias(lhs, rhs); })
1858       .def(
1859           "may_contain_alias",
1860           [](SchemaInfo& self,
1861              const SchemaArgument& lhs,
1862              const SchemaArgument& rhs) {
1863             return self.may_contain_alias(lhs, rhs);
1864           })
1865       .def(
1866           "add_argument_value",
1867           [](SchemaInfo& self,
1868              const std::string& name,
1869              const py::object& value) {
1870             std::optional<IValue> i_value = toTypeInferredIValueOptional(value);
1871             if (i_value) {
1872               // For normalization purposes there is an inconsistency within
1873               // torch.fx that turns all arguments named "self" into "input".
1874               // Thus this check ensures that those arguments are checked
1875               // correctly.
1876               if (name == "input" && !self.hasInputArgumentNamed("input")) {
1877                 self.addArgumentValue("self", *i_value);
1878               } else {
1879                 self.addArgumentValue(name, *i_value);
1880               }
1881             }
1882           })
1883       .def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
1884         std::unordered_map<std::string, IValue> value_map;
1885         for (const auto& key_pair : values) {
1886           IValue key = toTypeInferredIValue(key_pair.first);
1887           TORCH_INTERNAL_ASSERT(
1888               key.isString(),
1889               "add_argument_values() keys must be strings.");
1890           std::optional<IValue> value =
1891               toTypeInferredIValueOptional(key_pair.second);
1892           if (value) {
1893             // For normalization purposes there is an inconsistency within
1894             // torch.fx that turns all arguments named "self" into "input".
1895             // Thus this check ensures that those arguments are checked
1896             // correctly.
1897             if (key.toStringRef() == "input" &&
1898                 !self.hasInputArgumentNamed("input")) {
1899               self.addArgumentValue("self", *value);
1900             } else {
1901               value_map[key.toStringRef()] = *value;
1902             }
1903           }
1904         }
1905         self.addArgumentValues(value_map);
1906       });
1907   py::class_<FunctionSchema>(m, "FunctionSchema")
1908       .def_property_readonly(
1909           "name", [](FunctionSchema& self) { return self.name(); })
1910       .def_property_readonly(
1911           "overload_name",
1912           [](FunctionSchema& self) { return self.overload_name(); })
1913       .def_property_readonly(
1914           "arguments", [](FunctionSchema& self) { return self.arguments(); })
1915       .def_property_readonly(
1916           "returns", [](FunctionSchema& self) { return self.returns(); })
1917       .def(
1918           "is_backward_compatible_with",
1919           [](const FunctionSchema& self, const FunctionSchema& old_schema) {
1920             return self.isBackwardCompatibleWith(old_schema);
1921           })
1922       .def(
1923           "check_forward_compatible_with",
1924           [](const FunctionSchema& self, const FunctionSchema& old_schema) {
1925             std::ostringstream out;
1926             auto result = self.isForwardCompatibleWith(old_schema, out);
1927             return std::make_pair(result, out.str());
1928           })
1929       .def(
1930           "__eq__",
1931           [](const FunctionSchema& self, const FunctionSchema& other) {
1932             return self == other;
1933           })
1934       .def(
1935           "__hash__",
1936           [](const FunctionSchema& self) {
1937             return std::hash<FunctionSchema>{}(self);
1938           })
1939       .def(
1940           "__str__",
1941           [](FunctionSchema& self) {
1942             std::stringstream ss;
1943             ss << self;
1944             return ss.str();
1945           })
1946       .def(
1947           "__repr__",
1948           [](FunctionSchema& self) {
1949             std::stringstream ss;
1950             ss << self;
1951             return ss.str();
1952           })
1953       .def(py::pickle(
1954           [](const FunctionSchema& self) { // __getstate__
1955             std::stringstream ss;
1956             ss << self;
1957             return py::str(ss.str());
1958           },
1959           [](const py::str& schema) { // __setstate__, note: no `self` argument
1960             return parseSchema(schema);
1961           }))
1962       .def_property_readonly(
1963           "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); });
1964   py::class_<Argument>(m, "Argument")
1965       .def_property_readonly("name", [](Argument& self) { return self.name(); })
1966       .def_property_readonly("type", [](Argument& self) { return self.type(); })
1967       .def_property_readonly(
1968           "real_type", [](Argument& self) { return self.real_type(); })
1969       .def_property_readonly(
1970           "N",
1971           [](Argument& self) -> py::object {
1972             return (self.N()) ? py::cast(*self.N()) : py::none();
1973           })
1974       .def_property_readonly(
1975           "default_value",
1976           [](Argument& self) -> py::object {
1977             if (!self.default_value()) {
1978               return py::none();
1979             }
1980             IValue v = *self.default_value();
1981             return toPyObject(std::move(v));
1982           })
1983       .def(
1984           "has_default_value",
1985           [](Argument& self) -> py::bool_ {
1986             return self.default_value().has_value();
1987           })
1988       .def_property_readonly(
1989           "alias_info", [](Argument& self) { return self.alias_info(); })
1990       .def_property_readonly(
1991           "is_out", [](Argument& self) { return self.is_out(); })
1992       .def_property_readonly("kwarg_only", [](Argument& self) -> bool {
1993         return self.kwarg_only();
1994       });
1995   py::class_<AliasInfo>(m, "_AliasInfo")
1996       .def_property_readonly(
1997           "is_write", [](AliasInfo& self) { return self.isWrite(); })
1998       .def_property_readonly(
1999           "before_set",
2000           [](AliasInfo& self) {
2001             std::set<py::str> before_set_python;
2002             for (const auto& set : self.beforeSets()) {
2003               before_set_python.insert(py::str(set.toUnqualString()));
2004             }
2005             return before_set_python;
2006           })
2007       .def_property_readonly("after_set", [](AliasInfo& self) {
2008         std::set<py::str> after_set_python;
2009         for (const auto& set : self.afterSets()) {
2010           after_set_python.insert(py::str(set.toUnqualString()));
2011         }
2012         return after_set_python;
2013       });
2014   m.def("_jit_get_all_schemas", []() {
2015     const std::vector<std::shared_ptr<Operator>>& operations =
2016         getAllOperators();
2017     return fmap(operations, [](const std::shared_ptr<Operator>& op) {
2018       return op->schema();
2019     });
2020   });
2021   m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck);
2022   m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
2023     auto symbol = Symbol::fromQualString(qualified_name);
2024     const auto& operations = getAllOperatorsFor(symbol);
2025     return fmap(operations, [](const std::shared_ptr<Operator>& op) {
2026       return op->schema();
2027     });
2028   });
2029   m.def("_is_tracing", []() { return jit::tracer::isTracing(); });
2030 
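  // torch.futures.Future: value, wait, then and add_done_callback release
  // the GIL while they run; done and set_result intentionally do not.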
2031   py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
2032       m, "Future")
2033       .def(py::init([](std::vector<c10::Device> devices = {}) {
2034         return std::make_shared<PythonFutureWrapper>(
2035             c10::make_intrusive<c10::ivalue::Future>(
2036                 PyObjectType::get(), std::move(devices)));
2037       }))
2038       .def(
2039           "done",
2040           // Intentionally not releasing GIL
2041           &PythonFutureWrapper::done)
2042       .def(
2043           "value",
2044           &PythonFutureWrapper::value,
2045           py::call_guard<py::gil_scoped_release>())
2046       .def(
2047           "wait",
2048           &PythonFutureWrapper::wait,
2049           py::call_guard<py::gil_scoped_release>())
2050       .def(
2051           "then",
2052           &PythonFutureWrapper::then,
2053           py::call_guard<py::gil_scoped_release>())
2054       .def(
2055           "add_done_callback",
2056           &PythonFutureWrapper::add_done_callback,
2057           py::call_guard<py::gil_scoped_release>())
2058       .def(
2059           "set_result",
2060           // Intentionally not releasing GIL
2061           &PythonFutureWrapper::markCompleted)
2062       .def(
2063           "_set_unwrap_func",
2064           // Intentionally not releasing GIL as this just does an assign
2065           [](PythonFutureWrapper& self, py::function unwrapFunc) {
2066             auto functionGuard =
2067                 std::make_shared<torch::jit::PythonFunctionGuard>(
2068                     std::move(unwrapFunc));
2069 
2070             std::function<void(py::object)> pf =
2071                 [functionGuard(std::move(functionGuard))](
2072                     const py::object& inp) {
2073                   return functionGuard->func_(inp);
2074                 };
2075             self.unwrap_func = std::move(pf);
2076           })
2077       .def(
2078           py::pickle(
2079               /* __getstate__ */
2080               [](const PythonFutureWrapper& /* unused */) {
2081                 TORCH_CHECK(false, "Can not pickle torch.futures.Future");
2082                 // Note that this return has no meaning since we always
2083                 // throw; it's only here to satisfy pybind11's API
2084                 // requirement.
2085                 return py::make_tuple();
2086               },
2087               /* __setstate__ */
2088               [](const py::tuple& /* unused */) { // NOLINT
2089                 TORCH_CHECK(false, "Can not unpickle torch.futures.Future");
2090                 // Note that this return has no meaning since we always
2091                 // throw; it's only here to satisfy pybind11's API
2092                 // requirement.
2093                 return nullptr;
2094               }),
2095           py::call_guard<py::gil_scoped_release>());
2096 
2097   py::class_<PythonAwaitWrapper, std::shared_ptr<PythonAwaitWrapper>>(
2098       m, "_Await")
2099       .def(
2100           "wait",
2101           &PythonAwaitWrapper::wait,
2102           py::call_guard<py::gil_scoped_release>())
2103       .def("fn", &PythonAwaitWrapper::fn)
2104       .def("args", &PythonAwaitWrapper::args)
2105       .def("type", &PythonAwaitWrapper::type)
2106       .def("is_nowait", &PythonAwaitWrapper::is_nowait)
2107       .def(
2108           "__getattr__",
2109           [](PythonAwaitWrapper& self, const std::string& name) -> py::object {
2110             // In eager mode, allow Await[W] to be used as W by redirecting
2111             // getattr to the result of the delayed function.
2112             return py::getattr(self.wait(), name.c_str(), py::none());
2113           })
2114       .def(
2115           py::pickle(
2116               /* __getstate__ */
2117               [](const PythonAwaitWrapper& /* unused */) {
2118                 TORCH_CHECK(false, "Can not pickle torch.jit._Await");
2119                 // Note that this return has no meaning since we always
2120                 // throw; it's only here to satisfy pybind11's API
2121                 // requirement.
2122                 return py::make_tuple();
2123               },
2124               /* __setstate__ */
2125               [](const py::tuple& /* unused */) { // NOLINT
2126                 TORCH_CHECK(false, "Can not unpickle torch.jit._Await");
2127                 // Note that this return has no meaning since we always
2128                 // throw; it's only here to satisfy pybind11's API
2129                 // requirement.
2130                 return nullptr;
2131               }),
2132           py::call_guard<py::gil_scoped_release>());
2133 
2134   m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
2135     std::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2136     std::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2137 
2138     // Only return true if we are certain that self and other are aliasing.
2139     if (!self_value || !other_value) {
2140       return false;
2141     }
2142     return self_value->isAliasOf(*other_value);
2143   });
2144   m.def("_overlaps", [](const py::object& self, const py::object& other) {
2145     std::optional<IValue> self_value = toTypeInferredIValueOptional(self);
2146     std::optional<IValue> other_value = toTypeInferredIValueOptional(other);
2147 
2148     // Only return true if we are certain that self and other are overlapping.
2149     if (!self_value || !other_value) {
2150       return false;
2151     }
2152     return self_value->overlaps(*other_value);
2153   });
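  // _awaitable(fn, *args) wraps `fn` and its arguments for delayed
  // execution; the call runs when the returned _Await is first waited on.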
2154   m.def("_awaitable", [](const py::args& args, const py::kwargs& kwargs) {
2155     AT_ASSERT(!args.empty());
2156     py::tuple args_tup(args.size() - 1);
2157     for (const auto i : c10::irange(1, args.size())) {
2158       args_tup[i - 1] = args[i];
2159     }
2160     return std::make_shared<PythonAwaitWrapper>(
2161         py::cast<py::function>(args[0]), std::move(args_tup));
2162   });
2163   m.def("_awaitable_nowait", [](py::handle input) {
2164     return std::make_shared<PythonAwaitWrapper>(input);
2165   });
2166   m.def(
2167       "_awaitable_wait", [](const std::shared_ptr<PythonAwaitWrapper>& py_aw) {
2168         TORCH_CHECK(py_aw, "Await can't be None");
2169         return py_aw->wait();
2170       });
2171   m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
2172     AT_ASSERT(!args.empty());
2173 
2174     py::function f = py::cast<py::function>(args[0]);
2175     py::tuple args_tup(args.size() - 1);
2176 
2177     for (const auto i : c10::irange(1, args.size())) {
2178       args_tup[i - 1] = args[i];
2179     }
2180 
2181     if (jit::tracer::isTracing()) {
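      // While tracing, emit a prim::TracedFork node and trace the body into
      // its sub-block; the function itself still runs eagerly below and its
      // completed result is wrapped in a Future.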
2182       auto graph = jit::tracer::getTracingState()->graph;
2183       auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1));
2184       auto body_block = fork_node->addBlock();
2185 
2186       Value* node_output = nullptr;
2187       py::object py_func_output;
2188       // Insert new trace ops into the fork op's sub-block
2189       WithInsertPoint guard(body_block);
2190       IValue output_ivalue;
2191       {
2192         tracer::WithNestedTracingFrame env_guard;
2193 
2194         // Run the user-supplied function
2195         py_func_output = f(*args_tup, **kwargs);
2196 
2197         // Convert the output of the user-supplied function to an IValue.
2198         // The type information of this IValue is used to record the
2199         // correct type in the trace.
2200         output_ivalue = toTypeInferredIValue(py_func_output);
2201         Value* out_val = jit::tracer::getValueTrace(output_ivalue);
2202         body_block->registerOutput(out_val);
2203         node_output =
2204             fork_node->output()->setType(FutureType::create(out_val->type()));
2205       }
2206 
2207       auto retval =
2208           c10::make_intrusive<c10::ivalue::Future>(output_ivalue.type());
2209 
2210       // Record the ivalue in the tracer
2211       jit::tracer::setValueTrace(retval, node_output);
2212 
2213       // stuff the ivalue output in the Future
2214       retval->markCompleted(output_ivalue);
2215 
2216       return std::make_shared<PythonFutureWrapper>(retval);
2217     } else {
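      // Not tracing: run the function inline and return an
      // already-completed Future holding its result.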
2218       auto result = toTypeInferredIValue(f(*args_tup, **kwargs));
2219       auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
2220       retval->markCompleted(std::move(result));
2221       return std::make_shared<PythonFutureWrapper>(retval);
2222     }
2223   });
2224 
2225   m.def("wait", [](const std::shared_ptr<PythonFutureWrapper>& fut) {
2226     TORCH_CHECK(fut, "Future can't be None");
2227     return fut->wait();
2228   });
2229 
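  // Collects a list of Futures into a single Future of a list that
  // completes once every input Future completes (via c10::collectAll).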
2230   m.def(
2231       "_collect_all",
2232       [](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
2233           -> std::shared_ptr<jit::PythonFutureWrapper> {
2234         auto typePtr = futures.empty() || futures[0] == nullptr
2235             ? AnyType::get()
2236             : futures[0]->fut->elementType();
2237         c10::List<c10::intrusive_ptr<c10::ivalue::Future>> asList(
2238             c10::FutureType::create(typePtr));
2239         asList.reserve(futures.size());
2240         for (const auto& f : futures) {
2241           TORCH_CHECK(f, "Future can't be None");
2242           asList.push_back(f->fut);
2243         }
2244         return std::make_shared<jit::PythonFutureWrapper>(
2245             c10::collectAll(asList),
2246             /* unwrap_func */ [futures](const py::object& /*unused*/) {
2247               // Throw errors when calling wait() on the returned Future if
2248               // any of the original futures would throw.
2249               // NB: PythonFutureWrapper takes an unwrap_func which serves as a
2250               // callback to evaluate the value in the Future. RPC uses this
2251               // unwrap_func to check whether the returned py::object is a
2252               // RemoteException object, and re-throws the exception if it is.
2253               // By extracting the c10::ivalue::Future from PythonFutureWrapper,
2254               // the unwrap_funcs of the original PythonFutureWrapper objects
2255               // are discarded, and hence wait() will return the RemoteException
2256               // as an object instead of re-throwing it.
2257               for (auto& fut : futures) {
2258                 fut->wait();
2259               }
2260             });
2261       },
2262       py::call_guard<py::gil_scoped_release>());
2263 
2264   m.def("_jit_assert_is_instance", [](py::object obj, const TypePtr& type) {
2265     toIValue(std::move(obj), type);
2266   });
2267 
2268 #if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
2269   m.def("_set_print_stack_traces_on_fatal_signal", [](bool print) {
2270     c10::FatalSignalHandler::getInstance().setPrintStackTracesOnFatalSignal(
2271         print);
2272   });
2273 #endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
2274 
2275   initPythonCustomClassBindings(module);
2276   initPythonIRBindings(module);
2277   tracer::initPythonTracerBindings(module);
2278   initTreeViewBindings(module);
2279   initJitScriptBindings(module);
2280   initJitBackendBindings(module);
2281   initStaticModuleBindings(module);
2282   initTensorExprBindings(module);
2283   // initNvFuserPythonBindings(module);
2284 
2285   setPrintHandler([](const std::string& str) {
2286     py::gil_scoped_acquire acquire;
2287     try {
2288       auto _stdout = py::module::import("sys").attr("stdout");
2289       _stdout.attr("write")(str);
2290     } catch (py::error_already_set& e) {
2291       throw std::runtime_error(e.what());
2292     }
2293   });
2294 
2295   // On exit we need to reset the print handler to the default one, because
2296   // otherwise the prim::Print() instruction won't work for JIT modules.
2297   auto atexit = py::module_::import("atexit");
2298   atexit.attr("register")(
2299       py::cpp_function([]() { setPrintHandler(getDefaultPrintHandler()); }));
2300 }
2301 
2302 } // namespace torch::jit
2303