xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx.h>
2 
3 #include <ATen/core/functional.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/irange.h>
6 #include <torch/csrc/autograd/function.h>
7 #include <torch/csrc/autograd/symbolic.h>
8 #include <torch/csrc/jit/ir/constants.h>
9 #include <torch/csrc/jit/jit_log.h>
10 #include <torch/csrc/jit/passes/dead_code_elimination.h>
11 #include <torch/csrc/jit/passes/onnx/constant_map.h>
12 #include <torch/csrc/jit/passes/onnx/helper.h>
13 #include <torch/csrc/jit/passes/onnx/onnx_log.h>
14 #include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
15 #include <torch/csrc/jit/python/python_ir.h>
16 #include <torch/csrc/utils/pybind.h>
17 #include <sstream>
18 
19 namespace torch::jit {
20 
removePrintOps(Block * block)21 void removePrintOps(Block* block) {
22   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
23        ++it) {
24     for (auto b : it->blocks()) {
25       removePrintOps(b);
26     }
27     if (it->kind() == prim::Print || it->kind() == aten::warn) {
28       for (size_t i = 0; i < it->inputs().size();) {
29         auto input = it->inputs().at(i);
30         // only handling constants bc of potential side effects
31         if (input->uses().size() == 1 &&
32             input->node()->kind() == prim::Constant) {
33           it->removeInput(i);
34           input->node()->destroy();
35         } else {
36           ++i;
37         }
38       }
39       it.destroyCurrent();
40     }
41   }
42 }
43 
RemovePrintOps(std::shared_ptr<Graph> & graph)44 void RemovePrintOps(std::shared_ptr<Graph>& graph) {
45   removePrintOps(graph->block());
46   GRAPH_DUMP("After RemovePrintOps: ", graph);
47 }
48 
checkONNXCompatibility(const c10::FunctionSchema & schema)49 void checkONNXCompatibility(const c10::FunctionSchema& schema) {
50   // in ONNX, all inputs are tensors, no support for tensor list
51   // so at most one input tensor list is supported
52   bool has_tensor_list = false;
53   const auto& args = schema.arguments();
54   for (const auto& arg : args) {
55     if (arg.name() == "_caffe2_preallocated_outputs") {
56       continue;
57     }
58     auto type = arg.type();
59     if (type->kind() == TypeKind::OptionalType) {
60       type = reinterpret_cast<OptionalType*>(type.get())->getElementType();
61       // recursive optional type is not supported
62       TORCH_INTERNAL_ASSERT(type->kind() != TypeKind::OptionalType);
63     }
64     if (type->kind() == TypeKind::ListType) {
65       const auto& elem_type =
66           reinterpret_cast<ListType*>(type.get())->getElementType();
67       if (elem_type->isSubtypeOf(*TensorType::get())) {
68         TORCH_INTERNAL_ASSERT(
69             !has_tensor_list,
70             "ONNX export supports at most one TensorList as input.");
71         has_tensor_list = true;
72       }
73     }
74   }
75 }
76 
preprocessCaffe2Ops(Block * block)77 void preprocessCaffe2Ops(Block* block) {
78   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
79        ++it) {
80     for (auto b : it->blocks()) {
81       preprocessCaffe2Ops(b);
82     }
83     if (it->kind().is_caffe2()) {
84       const auto& schema = it->schema();
85       checkONNXCompatibility(schema);
86       std::vector<Value*> origin_inputs;
87       for (Value* v : it->inputs()) {
88         origin_inputs.push_back(v);
89       }
90       it->removeAllInputs();
91       const auto& args = schema.arguments();
92       size_t origin_inputs_index = 0;
93       for (const auto& arg : args) {
94         const auto& type = arg.type();
95         TORCH_INTERNAL_ASSERT(origin_inputs_index < origin_inputs.size());
96         const auto& origin_input = origin_inputs[origin_inputs_index++];
97         if (type->kind() == TypeKind::OptionalType &&
98             origin_input->mustBeNone()) {
99           continue;
100         }
101         if (type->isSubtypeOf(*TensorType::get())) {
102           it->addInput(origin_input);
103         } else if (
104             type->kind() == TypeKind::BoolType ||
105             type->kind() == TypeKind::IntType) {
106           const auto* constant_node = origin_input->node();
107           TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
108           it->i_(Symbol::attr(arg.name()), constant_node->i(attr::value));
109         } else if (type->kind() == TypeKind::FloatType) {
110           const auto* constant_node = origin_input->node();
111           TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
112           it->f_(Symbol::attr(arg.name()), constant_node->f(attr::value));
113         } else if (type->kind() == TypeKind::StringType) {
114           const auto* constant_node = origin_input->node();
115           TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
116           it->s_(Symbol::attr(arg.name()), constant_node->s(attr::value));
117         } else if (type->kind() == TypeKind::ListType) {
118           const auto& list_node = origin_input->node();
119           const auto& elem_type = type->castRaw<ListType>()->getElementType();
120           TORCH_INTERNAL_ASSERT(
121               list_node->kind() == prim::ListConstruct ||
122               list_node->kind() == prim::Constant);
123           if (elem_type->isSubtypeOf(*TensorType::get())) {
124             TORCH_INTERNAL_ASSERT(list_node->kind(), prim::ListConstruct);
125             const auto& tensor_list = origin_input->node()->inputs();
126             for (const auto& t : tensor_list) {
127               it->addInput(t);
128             }
129           } else if (elem_type->kind() == TypeKind::FloatType) {
130             std::vector<double> values;
131             if (list_node->kind() == prim::ListConstruct) {
132               for (const auto* elem_input : list_node->inputs()) {
133                 const auto* constant_node = elem_input->node();
134                 TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
135                 values.push_back(constant_node->f(attr::value));
136               }
137             } else { // is a constant list
138               values = list_node->fs(attr::value);
139             }
140             it->fs_(Symbol::attr(arg.name()), values);
141           } else {
142             throw std::runtime_error(
143                 "Unhandled scalar arg: " + arg.name() +
144                 ", type: " + c10::typeKindToString(elem_type->kind()));
145           }
146         } else {
147           throw std::runtime_error(
148               "Unsupported input type of arg " + arg.name() +
149               " in Caffe2 operator: " + c10::typeKindToString(type->kind()));
150         }
151       }
152     }
153   }
154   EliminateDeadCode(
155       block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
156 }
157 
PreprocessCaffe2Ops(std::shared_ptr<Graph> & graph)158 void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
159   preprocessCaffe2Ops(graph->block());
160   GRAPH_DUMP("After PreprocessCaffe2Ops: ", graph);
161 }
162 
163 // Transform PythonOps into Nodes that match ONNX semantics.
ToONNX(std::shared_ptr<Graph> & graph,::torch::onnx::OperatorExportTypes operator_export_type)164 std::shared_ptr<Graph> ToONNX(
165     std::shared_ptr<Graph>& graph,
166     ::torch::onnx::OperatorExportTypes operator_export_type) {
167   ConstantValueMap::ClearMaps();
168   auto new_graph = std::make_shared<Graph>(graph->current_scope());
169   py::dict env;
170   // Kept identical to values in env. Used for constant-time existance check.
171   py::set values_in_env;
172   try {
173     BlockToONNX(
174         graph->block(),
175         new_graph->block(),
176         operator_export_type,
177         env,
178         values_in_env);
179   } catch (std::runtime_error& ex) {
180     ONNX_LOG(
181         "ONNX graph being constructed during exception:\n",
182         new_graph->toString());
183     throw;
184   }
185   GRAPH_DUMP("after ToONNX: ", new_graph);
186   ConstantValueMap::ClearMaps();
187   return new_graph;
188 }
189 
190 // BlockToONNX.
191 // is_sub_block = true means the old_block (aten graph) is in the sub block
192 // (e.g., if sub block), and we want to convert it into its parent block in onnx
193 // graph. In this case, we don't register the input/output or eliminate the dead
194 // code.
BlockToONNX(Block * old_block,Block * new_block,::torch::onnx::OperatorExportTypes operator_export_type,py::dict & env,py::set & values_in_env,bool is_sub_block)195 py::dict BlockToONNX(
196     Block* old_block,
197     Block* new_block,
198     ::torch::onnx::OperatorExportTypes operator_export_type,
199     py::dict& env,
200     py::set& values_in_env,
201     bool is_sub_block) {
202   torch::autograd::SymbolicContext ctx{};
203   ctx.block = new_block;
204 
205   GRAPH_DEBUG(
206       "BlockToONNX: graph of old block: ",
207       old_block->owningGraph()->toString());
208 
209   // Initialize context and environment
210   if (!is_sub_block) {
211     for (auto input : old_block->inputs()) {
212       auto n = ctx.block->addInput()->copyMetadata(input);
213       auto py_n = py::cast(n);
214       env[py::cast(input)] = py_n;
215       values_in_env.add(py_n);
216     }
217   }
218 
219   // Determine if all inputs are static. This is used for each node to
220   // determine whether or not to propagate shapes.
221   if (!is_sub_block) {
222     bool static_input_shape = AllGraphInputsStatic(ctx.block->owningGraph());
223     ConstantValueMap::SetAllGraphInputsStatic(static_input_shape);
224   }
225 
226   // Finally, visit all nodes in the graph
227   for (auto node : old_block->nodes()) {
228     NodeToONNX(node, ctx.block, operator_export_type, env, values_in_env);
229   }
230 
231   if (is_sub_block) {
232     return env;
233   }
234 
235   for (auto output : old_block->outputs()) {
236     auto py_value = env[py::cast(output)];
237     Value* value = py_value.cast<Value*>();
238     ctx.block->registerOutput(value);
239   }
240   // Run dce to clean-up unused functional and inplace ops.
241   EliminateDeadCode(
242       ctx.block,
243       true,
244       DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
245 
246   return py::dict();
247 }
248 
ConstantFoldCondition(torch::jit::Value * output)249 bool ConstantFoldCondition(torch::jit::Value* output) {
250   auto fold_condition = output->node()->kind() != c10::onnx::Constant &&
251       ConstantValueMap::HasValue(output->debugName());
252   auto reliable_value =
253       ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false);
254   return fold_condition && reliable_value;
255 }
256 
NodeToONNX(Node * old_node,Block * new_block,::torch::onnx::OperatorExportTypes operator_export_type,py::dict & env,py::set & values_in_env)257 void NodeToONNX(
258     Node* old_node,
259     Block* new_block,
260     ::torch::onnx::OperatorExportTypes operator_export_type,
261     py::dict& env,
262     py::set& values_in_env) {
263   py::object onnx = py::module::import("torch.onnx");
264   py::object onnx_globals = py::module::import("torch.onnx._globals");
265   py::object onnx_registration =
266       py::module::import("torch.onnx._internal.registration");
267 
268   // Setup all the lambda helper functions.
269 
270   // Returns a node that n maps to in the new graph
271   auto envFn = [&env](Value* n) -> Value* {
272     auto py_n = py::cast(n);
273     TORCH_CHECK(env.contains(py_n), "Dangling node reference");
274     auto py_value = env[py_n];
275     TORCH_CHECK(!py_value.is_none(), "Unused node was subsequently used");
276     Value* value = py_value.cast<Value*>();
277     return value;
278   };
279 
280   // Put the new outputs in our environment map, and copy the type from the
281   // input graph if they were not set by the symbolic. This is called only
282   // with results of symbolic call (not for nodes that are just cloned).
283   auto setOutputs = [&](const std::string& op_name,
284                         Node* node,
285                         const value_list& outputs) {
286     auto old_outputs = node->outputs();
287     // Count all outputs, excluding Handles
288     auto num_old_outputs = old_outputs.size();
289     if (outputs.size() != num_old_outputs) {
290       std::ostringstream ss;
291       ss << "symbolic for " << op_name
292          << " produced an incorrect number of outputs (expected ";
293       ss << num_old_outputs << ", but got " << outputs.size() << ")";
294       throw std::runtime_error(ss.str());
295     }
296     // For const node, it does not need params_dict info, so set it to {}.
297     const ParamMap empty_params_dict = {};
298     auto opset_version = py::cast<int>(
299         onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
300     for (const auto i : c10::irange(num_old_outputs)) {
301       auto old = old_outputs[i];
302       if (outputs[i]) {
303         bool exist_in_env = values_in_env.contains(py::cast(outputs[i]));
304         // Update ONNX value debug name with ATen value debug name if existed.
305         // Skip if ONNX value already exist in environment.
306         // This implies the op is a noop, and the value is owned by
307         // other node created elsewhere.
308         if (old->hasDebugName() && !exist_in_env) {
309           auto old_name = outputs[i]->debugName();
310           auto new_name = old->debugNameBase();
311           Value* found_value = nullptr;
312           bool exists = false;
313           // In this scope, we fetch debug_names as a const reference and then
314           // construct an iterator exist_name based on it. This iterator will
315           // be corrupted if the underlying map of debug_names changes. This
316           // will happen as a side-effect of setDebugName. For these reasons,
317           // we make an explicit scope for exist_name and make sure that
318           // setDebugName is never called with this scope.
319           {
320             const auto& debug_names = new_block->owningGraph()->debugNames();
321             auto exist_name = debug_names.find(new_name);
322             exists = exist_name != debug_names.end();
323             if (exists) {
324               found_value = exist_name->second;
325             }
326           }
327           outputs[i]->setDebugName(new_name);
328           if (exists) {
329             found_value->setDebugName(new_name);
330           }
331           ConstantValueMap::UpdateValueName(old_name, outputs[i]->debugName());
332         }
333         // Allow symbolic() to skip specifying the type of the return node.
334         // Unfortunately, they are on the hook for all internal nodes
335         // (though in practice, the types are not computed.)
336         //
337         // If onnx shape inference is turned on, the new outputs will have
338         // types inferred, and they will be merged with the old types.
339         if (ConstantFoldCondition(outputs[i])) {
340           // Create a const node if the node output value is in
341           // ConstantValueMap.
342           auto value =
343               ConstantValueMap::GetValue(outputs[i]->debugName()).value();
344           Node* const_node =
345               new_block->owningGraph()->create(c10::onnx::Constant);
346           const_node->t_(attr::value, value);
347           const_node->output()->setType(TensorType::create(value));
348 
349           // Copy over source location and scope information to all nodes
350           // created by the symbolic
351           const_node->copyMetadata(node);
352           new_block->appendNode(const_node);
353           ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
354           auto py_output = py::cast(const_node->output());
355           env[py::cast(old)] = py_output;
356           values_in_env.add(py_output);
357         } else {
358           // An update in ConstantValueMap is also needed here, since
359           // the user setType can be only accessed in this step, and it
360           // should be reliable.
361           MergeInferredTypeAndSetMap(
362               outputs[i], old->type(), outputs[i]->type());
363           // non ONNX node with no type given will throw out the warnings here.
364           UpdateReliable(
365               outputs[i],
366               AreInputsReliableOrStatic(outputs[i]->node()),
367               /*no_type_warning=*/true);
368           // For the node type that does not have ComputeConstant logic, it may
369           // have reliable shape but its shape is not in ConstantValueMap. So we
370           // need to update ConstantValueMap.
371           UpdateShapeConstantIfReliable(outputs[i]);
372 
373           // Copy over source location and scope information to all nodes
374           // created by the symbolic
375           // Do not set metadata if outputs[i] is already in env.
376           if (!exist_in_env) {
377             outputs[i]->node()->copyMetadata(node);
378           }
379           auto py_output = py::cast(outputs[i]);
380           env[py::cast(old)] = py_output;
381           values_in_env.add(py_output);
382         }
383       } else {
384         // Null output means that the ONNX op doesn't have outputs corresponding
385         // to certain PyTorch outputs
386         env[py::cast(old)] = py::none();
387         if (!old->uses().empty()) {
388           std::ostringstream ss;
389           ss << "symbolic for " << op_name << " returned None for the output "
390              << i;
391           ss << " (indicating conversion for that particular output is not supported), ";
392           ss << "but the network uses this output later";
393           // TODO: Say what actually used it
394           throw std::runtime_error(ss.str());
395         }
396       }
397     }
398   };
399 
400   // Clone the node and add it to the new graph
401   auto cloneNode = [&](Node* node) {
402     auto n_ = new_block->appendNode(
403         new_block->owningGraph()->createClone(node, envFn));
404     for (const auto i : c10::irange(node->outputs().size())) {
405       // n_->outputs()[i]->setType(node->outputs()[i]->type());
406       auto py_output = py::cast(n_->output(i));
407       env[py::cast(node->output(i))] = py_output;
408       values_in_env.add(py_output);
409     }
410   };
411 
412   // Inline the prim::PythonOp sub-block nodes and append them to the onnx graph
413   auto inlineAutograd = [&](Node* PythonOpNode) {
414     for (auto subblock : PythonOpNode->blocks()) {
415       for (const auto i : c10::irange(PythonOpNode->inputs().size())) {
416         auto py_value = env[py::cast(PythonOpNode->inputs()[i])];
417         env[py::cast(subblock->inputs()[i])] = py_value;
418         values_in_env.add(py_value);
419       }
420       for (auto* node : subblock->nodes()) {
421         NodeToONNX(node, new_block, operator_export_type, env, values_in_env);
422       }
423       for (const auto i : c10::irange(PythonOpNode->outputs().size())) {
424         auto py_value = env[py::cast(subblock->outputs()[i])];
425         env[py::cast(PythonOpNode->outputs()[i])] = py_value;
426         values_in_env.add(py_value);
427       }
428     }
429   };
430 
431   // Cast output of symbolic() python implementation
432   auto processSymbolicOutput = [&](const std::string& op_name,
433                                    Node* n,
434                                    const py::object& raw_output) {
435     if (raw_output.ptr() == Py_None) {
436       cloneNode(n);
437       return;
438     }
439     // Cast the outputs back to C++ and put them in the new graph
440     std::vector<Value*> outputs;
441     try {
442       if (py::isinstance<Value>(raw_output)) {
443         outputs = value_list{py::cast<Value*>(raw_output)};
444       } else {
445         outputs = py::cast<std::vector<Value*>>(raw_output);
446       }
447     } catch (const std::exception& ex) {
448       std::ostringstream ss;
449       ss << "Error casting results of symbolic for " << op_name
450          << ": expected to return list of op nodes, instead received type ''"
451          << py::str(raw_output.get_type()) << "': " << py::str(raw_output);
452       throw std::runtime_error(ss.str());
453     }
454 
455     setOutputs(op_name, n, outputs);
456   };
457 
458   auto callPySymbolicFunction = [&](Node* n) {
459     // The idea is delegate as much of the actual argument massaging to
460     // Python as possible
461 
462     py::tuple py_inputs(n->inputs().size());
463     Py_ssize_t input_nr = 0;
464     for (auto* input : n->inputs()) {
465       py_inputs[input_nr++] = py::cast(envFn(input));
466     }
467 
468     Graph* g = new_block->owningGraph();
469 
470     WithInsertPoint insert_point_guard(new_block);
471     WithCurrentScope scope_guard(*g, n->scope());
472 
473     // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
474     // Python. Check #87343 for details.
475     py::list new_nodes = py::list();
476     py::object raw_output = onnx.attr("_run_symbolic_function")(
477         g->shared_from_this(),
478         new_block,
479         n,
480         py_inputs,
481         env,
482         values_in_env,
483         new_nodes,
484         operator_export_type);
485 
486     // Find new nodes that have been created by _run_symbolic_function and
487     // propagate metadata
488     for (py::handle py_node : new_nodes) {
489       Node* node = py_node.cast<Node*>();
490       node->copyMetadata(n);
491     }
492 
493     // TODO: Assert it's an ATen identifier???
494     // (Sometimes it's not...)
495     processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
496     GRAPH_DUMP("after processSymbolicOutput: ", g);
497   };
498 
499   auto callPySymbolicMethod = [&](ConcretePythonOp* op) {
500     // Test if there is a symbolic function; bail if there is not
501     auto pyobj = py::handle(op->pyobj.get());
502     auto func = op->autogradFunction();
503     if (func) {
504       pyobj = func->get();
505     }
506 
507     py::object opset_version =
508         onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
509     // NOTE(justinchuby): Call the internal registry to register the symbolic
510     // method defined in the module.
511     bool is_registered_op =
512         onnx_registration.attr("registry")
513             .attr("is_registered_op")("prim::PythonOp", opset_version)
514             .cast<bool>();
515     py::bool_ is_autograd_inlining_enabled =
516         py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
517     if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
518       // Inline the subgraph within the prim::PythonOp unless
519       // either of these conditions are satisfied
520       // 1. The torch.autograd.Function class of this node object has `symbolic`
521       // method defined.
522       // 2. Custom export symbolic is registered for prim::PythonOp.
523       if ((operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
524            operator_export_type ==
525                ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) &&
526           (py::cast<bool>(is_autograd_inlining_enabled))) {
527         try {
528           inlineAutograd(op);
529         } catch (const std::exception& ex) {
530           TORCH_WARN(
531               "Unable to inline PythonOp: ",
532               op->name(),
533               " due to the following exception\n",
534               ex.what(),
535               "prim::PythonOp will be exported as is and without being inlined\n",
536               "Try exporting with the following alternatives: \n",
537               "1) Set operator_export_type to ONNX_FALLTHROUGH mode\n",
538               "2) Register a symbolic method for the prim::PythonOp ",
539               op->name());
540           cloneNode(op);
541         }
542       } else {
543         cloneNode(op);
544       }
545       return;
546     }
547 
548     // Prepare args for Python. First one is the graph, and is followed
549     // by regular args, with Variables replaced by corresponding nodes.
550     Py_ssize_t input_nr = 0;
551     py::tuple py_symbolic_args(op->cconv.size());
552     auto inputs = op->inputs();
553     auto node_it = inputs.begin();
554     auto scalar_it = op->scalar_args.begin();
555     for (auto arg_type : op->cconv) {
556       py::object obj;
557       if (arg_type == 'c') {
558         TORCH_CHECK(
559             scalar_it != op->scalar_args.end(),
560             "expected too many scalar args");
561         obj = py::reinterpret_borrow<py::object>(
562             py::handle((scalar_it++)->get()));
563       } else if (arg_type == 'd') {
564         TORCH_CHECK(node_it != inputs.end(), "expected too many inputs");
565         obj = py::cast(envFn(*node_it++));
566       } else {
567         throw std::runtime_error("unexpected calling convention");
568       }
569       py_symbolic_args[input_nr++] = obj;
570     }
571 
572     WithInsertPoint insert_point_guard(new_block);
573     WithCurrentScope scope_guard(*new_block->owningGraph(), op->scope());
574 
575     if (py::hasattr(pyobj, "symbolic")) {
576       // Call the symbolic function
577       // Use a little trampoline function so we can give good error messages
578       // upon argument mismatch
579       // Register as a custom operator
580       // TODO: Find a more elegant way to do this without having to touch
581       // internal Python modules.
582       // TODO(justinchuby): Define a namespace for these Python Ops.
583       onnx_registration.attr("registry")
584           .attr("register")(
585               "::" + op->name(),
586               opset_version,
587               pyobj.attr("symbolic"),
588               /* custom */ true);
589 
590       // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
591       // Python. Check #87343 for details.
592       py::object raw_output = onnx.attr("_run_symbolic_method")(
593           new_block->owningGraph()->shared_from_this(),
594           op->name(),
595           pyobj.attr("symbolic"),
596           py_symbolic_args);
597 
598       processSymbolicOutput(op->name(), op, raw_output);
599     } else {
600       TORCH_INTERNAL_ASSERT(is_registered_op);
601       Node* n = static_cast<Node*>(op);
602       n->s_(attr::name, op->name());
603       // Call symbolic function
604       // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
605       // Python. Check #87343 for details.
606       py::list new_nodes = py::list();
607       py::object raw_output = onnx.attr("_run_symbolic_function")(
608           new_block->owningGraph()->shared_from_this(),
609           new_block,
610           n,
611           py_symbolic_args,
612           env,
613           values_in_env,
614           new_nodes,
615           operator_export_type);
616 
617       processSymbolicOutput(op->kind().toUnqualString(), n, raw_output);
618     }
619   };
620 
621   auto k = old_node->kind();
622   if (k.is_caffe2()) {
623     // Pass on Caffe2 operator, since we already preprocess it
624     cloneNode(old_node);
625   } else if (k == prim::PythonOp) {
626     callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
627   } else {
628     callPySymbolicFunction(old_node);
629   }
630 }
631 
632 } // namespace torch::jit
633