#include <torch/csrc/jit/passes/onnx.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/utils/pybind.h>
#include <sstream>

namespace torch::jit {

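// Recursively removes prim::Print and aten::warn nodes, which have no ONNX
// equivalent. A constant input is destroyed along with the node when the
// removed node was its only use; other inputs are left untouched because they
// may have side effects.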
void removePrintOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      removePrintOps(b);
    }
    if (it->kind() == prim::Print || it->kind() == aten::warn) {
      for (size_t i = 0; i < it->inputs().size();) {
        auto input = it->inputs().at(i);
        // only handle constants, because removing other inputs could drop
        // side effects
        if (input->uses().size() == 1 &&
            input->node()->kind() == prim::Constant) {
          it->removeInput(i);
          input->node()->destroy();
        } else {
          ++i;
        }
      }
      it.destroyCurrent();
    }
  }
}

void RemovePrintOps(std::shared_ptr<Graph>& graph) {
  removePrintOps(graph->block());
  GRAPH_DUMP("After RemovePrintOps: ", graph);
}

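// Validates that a Caffe2 operator schema can be mapped onto ONNX inputs and
// attributes; asserts if more than one TensorList input (or a nested
// Optional) is present.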
void checkONNXCompatibility(const c10::FunctionSchema& schema) {
  // In ONNX, all inputs are tensors and there is no tensor-list type, so at
  // most one tensor-list input is supported (it gets flattened into
  // individual tensor inputs during preprocessing).
  bool has_tensor_list = false;
  const auto& args = schema.arguments();
  for (const auto& arg : args) {
    if (arg.name() == "_caffe2_preallocated_outputs") {
      continue;
    }
    auto type = arg.type();
    if (type->kind() == TypeKind::OptionalType) {
      type = type->castRaw<OptionalType>()->getElementType();
      // recursive optional type is not supported
      TORCH_INTERNAL_ASSERT(type->kind() != TypeKind::OptionalType);
    }
    if (type->kind() == TypeKind::ListType) {
      const auto& elem_type = type->castRaw<ListType>()->getElementType();
      if (elem_type->isSubtypeOf(*TensorType::get())) {
        TORCH_INTERNAL_ASSERT(
            !has_tensor_list,
            "ONNX export supports at most one TensorList as input.");
        has_tensor_list = true;
      }
    }
  }
}

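// Converts the non-tensor inputs of Caffe2 operators into node attributes
// (ONNX has no notion of scalar inputs) and flattens a TensorList input into
// individual tensor inputs. For example (an illustrative sketch; the operator
// name is hypothetical):
//
//   %y = _caffe2::SomeOp(%x, %alpha)  # %alpha = prim::Constant[value=2.]()
//
// becomes
//
//   %y = _caffe2::SomeOp[alpha=2.](%x)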
void preprocessCaffe2Ops(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      preprocessCaffe2Ops(b);
    }
    if (it->kind().is_caffe2()) {
      const auto& schema = it->schema();
      checkONNXCompatibility(schema);
      std::vector<Value*> origin_inputs;
      for (Value* v : it->inputs()) {
        origin_inputs.push_back(v);
      }
      it->removeAllInputs();
      const auto& args = schema.arguments();
      size_t origin_inputs_index = 0;
      for (const auto& arg : args) {
        const auto& type = arg.type();
        TORCH_INTERNAL_ASSERT(origin_inputs_index < origin_inputs.size());
        const auto& origin_input = origin_inputs[origin_inputs_index++];
        if (type->kind() == TypeKind::OptionalType &&
            origin_input->mustBeNone()) {
          continue;
        }
        if (type->isSubtypeOf(*TensorType::get())) {
          it->addInput(origin_input);
        } else if (
            type->kind() == TypeKind::BoolType ||
            type->kind() == TypeKind::IntType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->i_(Symbol::attr(arg.name()), constant_node->i(attr::value));
        } else if (type->kind() == TypeKind::FloatType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->f_(Symbol::attr(arg.name()), constant_node->f(attr::value));
        } else if (type->kind() == TypeKind::StringType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->s_(Symbol::attr(arg.name()), constant_node->s(attr::value));
        } else if (type->kind() == TypeKind::ListType) {
          const auto& list_node = origin_input->node();
          const auto& elem_type = type->castRaw<ListType>()->getElementType();
          TORCH_INTERNAL_ASSERT(
              list_node->kind() == prim::ListConstruct ||
              list_node->kind() == prim::Constant);
          if (elem_type->isSubtypeOf(*TensorType::get())) {
            TORCH_INTERNAL_ASSERT(list_node->kind() == prim::ListConstruct);
            const auto& tensor_list = origin_input->node()->inputs();
            for (const auto& t : tensor_list) {
              it->addInput(t);
            }
          } else if (elem_type->kind() == TypeKind::FloatType) {
            std::vector<double> values;
            if (list_node->kind() == prim::ListConstruct) {
              for (const auto* elem_input : list_node->inputs()) {
                const auto* constant_node = elem_input->node();
                TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
                values.push_back(constant_node->f(attr::value));
              }
            } else { // it is a constant list
              values = list_node->fs(attr::value);
            }
            it->fs_(Symbol::attr(arg.name()), values);
          } else {
            throw std::runtime_error(
                "Unhandled scalar arg: " + arg.name() +
                ", type: " + c10::typeKindToString(elem_type->kind()));
          }
        } else {
          throw std::runtime_error(
              "Unsupported input type of arg " + arg.name() +
              " in Caffe2 operator: " + c10::typeKindToString(type->kind()));
        }
      }
    }
  }
  EliminateDeadCode(
      block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}

void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
  preprocessCaffe2Ops(graph->block());
  GRAPH_DUMP("After PreprocessCaffe2Ops: ", graph);
}

// Transform PythonOps into Nodes that match ONNX semantics.
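// A fresh graph is produced rather than mutating the input graph; the
// ConstantValueMap singleton is cleared before and after conversion so that
// no shape or constant state leaks between exports.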
std::shared_ptr<Graph> ToONNX(
    std::shared_ptr<Graph>& graph,
    ::torch::onnx::OperatorExportTypes operator_export_type) {
  ConstantValueMap::ClearMaps();
  auto new_graph = std::make_shared<Graph>(graph->current_scope());
  py::dict env;
  // Kept identical to the values in env. Used for constant-time existence
  // checks.
  py::set values_in_env;
  try {
    BlockToONNX(
        graph->block(),
        new_graph->block(),
        operator_export_type,
        env,
        values_in_env);
  } catch (std::runtime_error& ex) {
    ONNX_LOG(
        "ONNX graph being constructed during exception:\n",
        new_graph->toString());
    throw;
  }
  GRAPH_DUMP("after ToONNX: ", new_graph);
  ConstantValueMap::ClearMaps();
  return new_graph;
}

// BlockToONNX.
// is_sub_block = true means the old_block (ATen graph) lives inside a
// sub-block (e.g., an If node's sub-block), and we want to convert it into
// its parent block in the ONNX graph. In that case, we don't register the
// block's inputs/outputs or eliminate dead code.
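// When converting a sub-block, the updated env is returned so the caller can
// look up the mapped values; otherwise an empty dict is returned.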
py::dict BlockToONNX(
    Block* old_block,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    py::dict& env,
    py::set& values_in_env,
    bool is_sub_block) {
  torch::autograd::SymbolicContext ctx{};
  ctx.block = new_block;

  GRAPH_DEBUG(
      "BlockToONNX: graph of old block: ",
      old_block->owningGraph()->toString());

  // Initialize context and environment
  if (!is_sub_block) {
    for (auto input : old_block->inputs()) {
      auto n = ctx.block->addInput()->copyMetadata(input);
      auto py_n = py::cast(n);
      env[py::cast(input)] = py_n;
      values_in_env.add(py_n);
    }
  }

  // Determine if all inputs are static. This is used for each node to
  // determine whether or not to propagate shapes.
  if (!is_sub_block) {
    bool static_input_shape = AllGraphInputsStatic(ctx.block->owningGraph());
    ConstantValueMap::SetAllGraphInputsStatic(static_input_shape);
  }

  // Finally, visit all nodes in the graph
  for (auto node : old_block->nodes()) {
    NodeToONNX(node, ctx.block, operator_export_type, env, values_in_env);
  }

  if (is_sub_block) {
    return env;
  }

  for (auto output : old_block->outputs()) {
    auto py_value = env[py::cast(output)];
    Value* value = py_value.cast<Value*>();
    ctx.block->registerOutput(value);
  }
  // Run DCE to clean up unused functional and inplace ops.
  EliminateDeadCode(
      ctx.block,
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);

  return py::dict();
}

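// True when an output can be folded into an onnx::Constant node: its value is
// already known in ConstantValueMap, its type is marked reliable, and it is
// not itself produced by an onnx::Constant.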
bool ConstantFoldCondition(torch::jit::Value* output) {
  auto fold_condition = output->node()->kind() != c10::onnx::Constant &&
      ConstantValueMap::HasValue(output->debugName());
  auto reliable_value =
      ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false);
  return fold_condition && reliable_value;
}

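// Converts a single node and appends the result to new_block. Dispatch (see
// the bottom of this function): Caffe2 ops are cloned as-is since they were
// already preprocessed, prim::PythonOp goes through callPySymbolicMethod, and
// everything else goes through callPySymbolicFunction.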
void NodeToONNX(
    Node* old_node,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    py::dict& env,
    py::set& values_in_env) {
  py::object onnx = py::module::import("torch.onnx");
  py::object onnx_globals = py::module::import("torch.onnx._globals");
  py::object onnx_registration =
      py::module::import("torch.onnx._internal.registration");

  // Set up all the lambda helper functions.

  // Returns the value that n maps to in the new graph
  auto envFn = [&env](Value* n) -> Value* {
    auto py_n = py::cast(n);
    TORCH_CHECK(env.contains(py_n), "Dangling node reference");
    auto py_value = env[py_n];
    TORCH_CHECK(!py_value.is_none(), "Unused node was subsequently used");
    Value* value = py_value.cast<Value*>();
    return value;
  };

  // Put the new outputs in our environment map, and copy the type from the
  // input graph if they were not set by the symbolic. This is called only
  // with the results of a symbolic call (not for nodes that are just cloned).
  auto setOutputs = [&](const std::string& op_name,
                        Node* node,
                        const value_list& outputs) {
    auto old_outputs = node->outputs();
    // Count all outputs, excluding Handles
    auto num_old_outputs = old_outputs.size();
    if (outputs.size() != num_old_outputs) {
      std::ostringstream ss;
      ss << "symbolic for " << op_name
         << " produced an incorrect number of outputs (expected ";
      ss << num_old_outputs << ", but got " << outputs.size() << ")";
      throw std::runtime_error(ss.str());
    }
    // A Constant node does not need params_dict info, so set it to {}.
    const ParamMap empty_params_dict = {};
    auto opset_version = py::cast<int>(
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
    for (const auto i : c10::irange(num_old_outputs)) {
      auto old = old_outputs[i];
      if (outputs[i]) {
        bool exist_in_env = values_in_env.contains(py::cast(outputs[i]));
        // Update the ONNX value's debug name with the ATen value's debug name
        // if one exists. Skip if the ONNX value already exists in the
        // environment, which implies the op is a no-op and the value is owned
        // by another node created elsewhere.
        if (old->hasDebugName() && !exist_in_env) {
          auto old_name = outputs[i]->debugName();
          auto new_name = old->debugNameBase();
          Value* found_value = nullptr;
          bool exists = false;
          // In this scope, we fetch debug_names as a const reference and then
          // construct an iterator exist_name based on it. This iterator will
          // be invalidated if the underlying map of debug_names changes, which
          // happens as a side effect of setDebugName. For these reasons, we
          // make an explicit scope for exist_name and ensure that setDebugName
          // is never called within this scope.
          {
            const auto& debug_names = new_block->owningGraph()->debugNames();
            auto exist_name = debug_names.find(new_name);
            exists = exist_name != debug_names.end();
            if (exists) {
              found_value = exist_name->second;
            }
          }
          outputs[i]->setDebugName(new_name);
          if (exists) {
            found_value->setDebugName(new_name);
          }
          ConstantValueMap::UpdateValueName(old_name, outputs[i]->debugName());
        }
        // Allow symbolic() to skip specifying the type of the return node.
        // Unfortunately, they are on the hook for all internal nodes
        // (though in practice, the types are not computed).
        //
        // If ONNX shape inference is turned on, the new outputs will have
        // types inferred, and they will be merged with the old types.
        if (ConstantFoldCondition(outputs[i])) {
          // Create a Constant node if the node's output value is in
          // ConstantValueMap.
          auto value =
              ConstantValueMap::GetValue(outputs[i]->debugName()).value();
          Node* const_node =
              new_block->owningGraph()->create(c10::onnx::Constant);
          const_node->t_(attr::value, value);
          const_node->output()->setType(TensorType::create(value));

          // Copy over source location and scope information to all nodes
          // created by the symbolic
          const_node->copyMetadata(node);
          new_block->appendNode(const_node);
          ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
          auto py_output = py::cast(const_node->output());
          env[py::cast(old)] = py_output;
          values_in_env.add(py_output);
        } else {
          // An update of ConstantValueMap is also needed here, since the
          // user-provided setType can only be accessed in this step, and it
          // should be considered reliable.
          MergeInferredTypeAndSetMap(
              outputs[i], old->type(), outputs[i]->type());
          // A non-ONNX node with no type given will emit warnings here.
          UpdateReliable(
              outputs[i],
              AreInputsReliableOrStatic(outputs[i]->node()),
              /*no_type_warning=*/true);
          // Node types that have no ComputeConstant logic may have a reliable
          // shape that is not yet in ConstantValueMap, so we need to update
          // ConstantValueMap here.
          UpdateShapeConstantIfReliable(outputs[i]);

          // Copy over source location and scope information to all nodes
          // created by the symbolic.
          // Do not set metadata if outputs[i] is already in env.
          if (!exist_in_env) {
            outputs[i]->node()->copyMetadata(node);
          }
          auto py_output = py::cast(outputs[i]);
          env[py::cast(old)] = py_output;
          values_in_env.add(py_output);
        }
      } else {
        // Null output means that the ONNX op doesn't have outputs
        // corresponding to certain PyTorch outputs
        env[py::cast(old)] = py::none();
        if (!old->uses().empty()) {
          std::ostringstream ss;
          ss << "symbolic for " << op_name << " returned None for the output "
             << i;
          ss << " (indicating conversion for that particular output is not supported), ";
          ss << "but the network uses this output later";
          // TODO: Say what actually used it
          throw std::runtime_error(ss.str());
        }
      }
    }
  };

  // Clone the node and add it to the new graph
  auto cloneNode = [&](Node* node) {
    auto n_ = new_block->appendNode(
        new_block->owningGraph()->createClone(node, envFn));
    for (const auto i : c10::irange(node->outputs().size())) {
      // n_->outputs()[i]->setType(node->outputs()[i]->type());
      auto py_output = py::cast(n_->output(i));
      env[py::cast(node->output(i))] = py_output;
      values_in_env.add(py_output);
    }
  };

  // Inline the prim::PythonOp sub-block nodes and append them to the ONNX
  // graph.
  auto inlineAutograd = [&](Node* PythonOpNode) {
    for (auto subblock : PythonOpNode->blocks()) {
      for (const auto i : c10::irange(PythonOpNode->inputs().size())) {
        auto py_value = env[py::cast(PythonOpNode->inputs()[i])];
        env[py::cast(subblock->inputs()[i])] = py_value;
        values_in_env.add(py_value);
      }
      for (auto* node : subblock->nodes()) {
        NodeToONNX(node, new_block, operator_export_type, env, values_in_env);
      }
      for (const auto i : c10::irange(PythonOpNode->outputs().size())) {
        auto py_value = env[py::cast(subblock->outputs()[i])];
        env[py::cast(PythonOpNode->outputs()[i])] = py_value;
        values_in_env.add(py_value);
      }
    }
  };

  // Cast output of symbolic() Python implementation
  auto processSymbolicOutput = [&](const std::string& op_name,
                                   Node* n,
                                   const py::object& raw_output) {
    if (raw_output.ptr() == Py_None) {
      cloneNode(n);
      return;
    }
    // Cast the outputs back to C++ and put them in the new graph
    std::vector<Value*> outputs;
    try {
      if (py::isinstance<Value>(raw_output)) {
        outputs = value_list{py::cast<Value*>(raw_output)};
      } else {
        outputs = py::cast<std::vector<Value*>>(raw_output);
      }
    } catch (const std::exception& ex) {
      std::ostringstream ss;
      ss << "Error casting results of symbolic for " << op_name
         << ": expected to return list of op nodes, instead received type '"
         << py::str(raw_output.get_type()) << "': " << py::str(raw_output);
      throw std::runtime_error(ss.str());
    }

    setOutputs(op_name, n, outputs);
  };

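  // Converts a regular (non-PythonOp) node by delegating to torch.onnx's
  // _run_symbolic_function and registering the returned outputs.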
  auto callPySymbolicFunction = [&](Node* n) {
    // The idea is to delegate as much of the actual argument massaging to
    // Python as possible.

    py::tuple py_inputs(n->inputs().size());
    Py_ssize_t input_nr = 0;
    for (auto* input : n->inputs()) {
      py_inputs[input_nr++] = py::cast(envFn(input));
    }

    Graph* g = new_block->owningGraph();

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*g, n->scope());

    // IMPORTANT: NEVER pass a raw pointer to a smart-pointer-managed object
    // to Python. See #87343 for details.
    py::list new_nodes = py::list();
    py::object raw_output = onnx.attr("_run_symbolic_function")(
        g->shared_from_this(),
        new_block,
        n,
        py_inputs,
        env,
        values_in_env,
        new_nodes,
        operator_export_type);

    // Find the new nodes that were created by _run_symbolic_function and
    // propagate metadata to them.
    for (py::handle py_node : new_nodes) {
      Node* node = py_node.cast<Node*>();
      node->copyMetadata(n);
    }

    // TODO: Assert it's an ATen identifier???
    // (Sometimes it's not...)
    processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
    GRAPH_DUMP("after processSymbolicOutput: ", g);
  };

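  // Converts a prim::PythonOp. If the autograd Function defines a `symbolic`
  // method or a custom symbolic is registered, it is called; otherwise the
  // op's subgraph is inlined when autograd inlining is enabled, falling back
  // to cloning the node as-is.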
  auto callPySymbolicMethod = [&](ConcretePythonOp* op) {
    // Test if there is a symbolic function; bail if there is not
    auto pyobj = py::handle(op->pyobj.get());
    auto func = op->autogradFunction();
    if (func) {
      pyobj = func->get();
    }

    py::object opset_version =
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
    // NOTE(justinchuby): Call the internal registry to register the symbolic
    // method defined in the module.
    bool is_registered_op =
        onnx_registration.attr("registry")
            .attr("is_registered_op")("prim::PythonOp", opset_version)
            .cast<bool>();
    py::bool_ is_autograd_inlining_enabled =
        py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
    if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
      // Inline the subgraph within the prim::PythonOp unless either of these
      // conditions is satisfied:
      // 1. The torch.autograd.Function class of this node object has a
      //    `symbolic` method defined.
      // 2. A custom export symbolic is registered for prim::PythonOp.
      if ((operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
           operator_export_type ==
               ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) &&
          (py::cast<bool>(is_autograd_inlining_enabled))) {
        try {
          inlineAutograd(op);
        } catch (const std::exception& ex) {
          TORCH_WARN(
              "Unable to inline PythonOp: ",
              op->name(),
              " due to the following exception:\n",
              ex.what(),
              "\nprim::PythonOp will be exported as-is, without being inlined. ",
              "Try exporting with the following alternatives:\n",
              "1) Set operator_export_type to ONNX_FALLTHROUGH mode\n",
              "2) Register a symbolic method for the prim::PythonOp ",
              op->name());
          cloneNode(op);
        }
      } else {
        cloneNode(op);
      }
      return;
    }

    // Prepare args for Python. First one is the graph, and is followed
    // by regular args, with Variables replaced by corresponding nodes.
    Py_ssize_t input_nr = 0;
    py::tuple py_symbolic_args(op->cconv.size());
    auto inputs = op->inputs();
    auto node_it = inputs.begin();
    auto scalar_it = op->scalar_args.begin();
    for (auto arg_type : op->cconv) {
      py::object obj;
      if (arg_type == 'c') {
        TORCH_CHECK(
            scalar_it != op->scalar_args.end(),
            "ran out of scalar args");
        obj = py::reinterpret_borrow<py::object>(
            py::handle((scalar_it++)->get()));
      } else if (arg_type == 'd') {
        TORCH_CHECK(node_it != inputs.end(), "ran out of inputs");
        obj = py::cast(envFn(*node_it++));
      } else {
        throw std::runtime_error("unexpected calling convention");
      }
      py_symbolic_args[input_nr++] = obj;
    }

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*new_block->owningGraph(), op->scope());

    if (py::hasattr(pyobj, "symbolic")) {
      // Call the symbolic function.
      // Use a little trampoline function so we can give good error messages
      // upon argument mismatch.
      // Register it as a custom operator.
      // TODO: Find a more elegant way to do this without having to touch
      // internal Python modules.
      // TODO(justinchuby): Define a namespace for these Python ops.
      onnx_registration.attr("registry")
          .attr("register")(
              "::" + op->name(),
              opset_version,
              pyobj.attr("symbolic"),
              /* custom */ true);

      // IMPORTANT: NEVER pass a raw pointer to a smart-pointer-managed object
      // to Python. See #87343 for details.
      py::object raw_output = onnx.attr("_run_symbolic_method")(
          new_block->owningGraph()->shared_from_this(),
          op->name(),
          pyobj.attr("symbolic"),
          py_symbolic_args);

      processSymbolicOutput(op->name(), op, raw_output);
    } else {
      TORCH_INTERNAL_ASSERT(is_registered_op);
      Node* n = static_cast<Node*>(op);
      n->s_(attr::name, op->name());
      // Call the symbolic function.
      // IMPORTANT: NEVER pass a raw pointer to a smart-pointer-managed object
      // to Python. See #87343 for details.
      py::list new_nodes = py::list();
      py::object raw_output = onnx.attr("_run_symbolic_function")(
          new_block->owningGraph()->shared_from_this(),
          new_block,
          n,
          py_symbolic_args,
          env,
          values_in_env,
          new_nodes,
          operator_export_type);

      processSymbolicOutput(op->kind().toUnqualString(), n, raw_output);
    }
  };

  auto k = old_node->kind();
  if (k.is_caffe2()) {
    // Pass Caffe2 operators through, since we already preprocessed them
    cloneNode(old_node);
  } else if (k == prim::PythonOp) {
    callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
  } else {
    callPySymbolicFunction(old_node);
  }
}

} // namespace torch::jit