1# mypy: allow-untyped-defs 2"""Functions to export models into the ONNX IR format. 3 4These models can be loaded with the ONNX library and then 5converted to models which run on other deep learning frameworks. 6""" 7 8from __future__ import annotations 9 10import contextlib 11import copy 12import inspect 13import re 14import typing 15import warnings 16from typing import Any, Callable, cast, Collection, Mapping, Sequence 17 18import torch 19import torch._C._onnx as _C_onnx 20import torch.jit._trace 21import torch.serialization 22from torch import _C 23from torch.onnx import ( # noqa: F401 24 _constants, 25 _deprecation, 26 _exporter_states, 27 errors, 28 symbolic_helper, 29) 30from torch.onnx._globals import GLOBALS 31from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration 32 33 34__all__ = [ 35 "is_in_onnx_export", 36 "select_model_mode_for_export", 37 "disable_apex_o2_state_dict_hook", 38 "setup_onnx_logging", 39 "exporter_context", 40 "export", 41 "model_signature", 42 "warn_on_static_input_change", 43 "unpack_quantized_tensor", 44 "export_to_pretty_string", 45 "unconvertible_ops", 46 "register_custom_op_symbolic", 47 "unregister_custom_op_symbolic", 48] 49 50 51def is_in_onnx_export() -> bool: 52 """Returns whether it is in the middle of ONNX export.""" 53 return GLOBALS.in_onnx_export 54 55 56# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp 57# Skip check due to cannot import IValue from torch._C 58_params_dict = {} # type: ignore[var-annotated] 59 60 61@contextlib.contextmanager 62def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): 63 r"""A context manager to temporarily set the training mode of ``model`` 64 to ``mode``, resetting it when we exit the with-block. 65 66 Args: 67 model: Same type and meaning as ``model`` arg to :func:`export`. 68 mode: Same type and meaning as ``training`` arg to :func:`export`. 69 """ 70 if not isinstance(mode, _C_onnx.TrainingMode): 71 raise TypeError( 72 f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'." 73 ) 74 originally_training: bool = False 75 76 if hasattr(model, "training"): 77 originally_training = model.training 78 79 # ONNX opset 12 has better support for training amenable models, with updated 80 # versions of the dropout and batch_norm operators 81 if mode == _C_onnx.TrainingMode.TRAINING or ( 82 mode == _C_onnx.TrainingMode.PRESERVE and originally_training 83 ): 84 GLOBALS.export_training = True 85 if GLOBALS.export_onnx_opset_version < 12: 86 warnings.warn( 87 "You are exporting the model in training mode with onnx opset " 88 f"version {GLOBALS.export_onnx_opset_version}. " 89 "Opset versions lower than opset 12 will not be able to export " 90 "nodes such as Dropout and BatchNorm correctly." 91 ) 92 else: 93 GLOBALS.export_training = False 94 95 GLOBALS.training_mode = mode 96 if mode == _C_onnx.TrainingMode.TRAINING: 97 model.train(True) 98 elif mode == _C_onnx.TrainingMode.EVAL: 99 model.train(False) 100 # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing 101 102 try: 103 yield 104 finally: 105 if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE: 106 model.train(originally_training) 107 108 109@contextlib.contextmanager 110def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction): 111 # Apex O2 hook state_dict to return fp16 weights as fp32. 112 # Exporter cannot identify them as same tensors. 113 # Since this hook is only used by optimizer, it is safe to 114 # remove this hook while exporting. 115 if not isinstance(model, torch.jit.ScriptFunction): 116 model_hooks = {} # type: ignore[var-annotated] 117 for module in model.modules(): 118 for key, hook in module._state_dict_hooks.items(): 119 if type(hook).__name__ == "O2StateDictHook": 120 if module not in model_hooks: 121 model_hooks[module] = {} 122 model_hooks[module][key] = hook 123 if module in model_hooks: 124 for key in model_hooks[module]: 125 module._state_dict_hooks.pop(key) 126 try: 127 yield 128 finally: 129 # Add the hooks back 130 for module, m_map in model_hooks.items(): 131 for key, hook in m_map.items(): 132 module._state_dict_hooks[key] = hook 133 else: 134 try: 135 yield 136 finally: 137 pass 138 139 140@contextlib.contextmanager 141def setup_onnx_logging(verbose: bool): 142 is_originally_enabled = torch.onnx.is_onnx_log_enabled() 143 if is_originally_enabled or verbose: 144 torch.onnx.enable_log() 145 try: 146 yield 147 finally: 148 if not is_originally_enabled: 149 torch.onnx.disable_log() 150 151 152@contextlib.contextmanager 153def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): 154 with select_model_mode_for_export( 155 model, mode 156 ) as mode_ctx, disable_apex_o2_state_dict_hook( 157 model 158 ) as apex_ctx, setup_onnx_logging( 159 verbose 160 ) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx: 161 yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx) 162 163 164def _get_torch_export_args( 165 args: tuple[Any, ...], 166 kwargs: dict[str, Any] | None, 167) -> tuple[tuple[Any, ...], dict[str, Any] | None]: 168 """Obtain the arguments for torch.onnx.export from the model and the input arguments.""" 169 if not kwargs and args and isinstance(args[-1], dict): 170 kwargs = args[-1] 171 args = args[:-1] 172 return args, kwargs 173 174 175def export( 176 model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, 177 args: tuple[Any, ...] | torch.Tensor, 178 f: str, 179 *, 180 kwargs: dict[str, Any] | None = None, 181 export_params: bool = True, 182 verbose: bool = False, 183 training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, 184 input_names: Sequence[str] | None = None, 185 output_names: Sequence[str] | None = None, 186 operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, 187 opset_version: int | None = None, 188 do_constant_folding: bool = True, 189 dynamic_axes: Mapping[str, Mapping[int, str]] 190 | Mapping[str, Sequence[int]] 191 | None = None, 192 keep_initializers_as_inputs: bool | None = None, 193 custom_opsets: Mapping[str, int] | None = None, 194 export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, 195 autograd_inlining: bool = True, 196) -> None: 197 r"""Exports a model into ONNX format. 198 199 If ``model`` is not a :class:`torch.jit.ScriptModule` nor a 200 :class:`torch.jit.ScriptFunction`, this runs 201 ``model`` once in order to convert it to a TorchScript graph to be exported 202 (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support 203 for dynamic control flow as :func:`torch.jit.trace`. 204 205 Args: 206 model: The model to be exported. 207 args: 208 209 args can be structured either as: 210 211 1. ONLY A TUPLE OF ARGUMENTS:: 212 213 args = (x, y, z) 214 215 The tuple should contain model inputs such that ``model(*args)`` is a valid 216 invocation of the model. Any non-Tensor arguments will be hard-coded into the 217 exported model; any Tensor arguments will become inputs of the exported model, 218 in the order they occur in the tuple. 219 220 2. A TENSOR:: 221 222 args = torch.Tensor([1]) 223 224 This is equivalent to a 1-ary tuple of that Tensor. 225 226 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: 227 228 args = (x, {"y": input_y, "z": input_z}) 229 230 All but the last element of the tuple will be passed as non-keyword arguments, 231 and named arguments will be set from the last element. If a named argument is 232 not present in the dictionary, it is assigned the default value, or None if a 233 default value is not provided. 234 235 .. warning:: 236 This behavior will be deprecated in a future release. Please use the 237 kwargs argument instead. 238 239 .. note:: 240 If a dictionary is the last element of the args tuple, it will be 241 interpreted as containing named arguments. In order to pass a dict as the 242 last non-keyword arg, provide an empty dict as the last element of the args 243 tuple. For example, instead of:: 244 245 torch.onnx.export( 246 model, 247 ( 248 x, 249 # WRONG: will be interpreted as named arguments 250 {y: z}, 251 ), 252 "test.onnx.pb", 253 ) 254 255 Write:: 256 257 torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb") 258 259 f: Path to the output ONNX model file. E.g. "model.onnx". 260 kwargs: Named arguments to the model. 261 export_params: If True, all parameters will 262 be exported. Set this to False if you want to export an untrained model. 263 In this case, the exported model will first take all of its parameters 264 as arguments, with the ordering as specified by ``model.state_dict().values()`` 265 verbose: if True, prints a description of the 266 model being exported to stdout. In addition, the final ONNX graph will include the 267 field ``doc_string``` from the exported model which mentions the source code locations 268 for ``model``. If True, ONNX exporter logging will be turned on. 269 training: 270 * ``TrainingMode.EVAL``: export the model in inference mode. 271 * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is 272 False and in training mode if model.training is True. 273 * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations 274 which might interfere with training. 275 input_names (list of str, default empty list): names to assign to the 276 input nodes of the graph, in order. 277 output_names (list of str, default empty list): names to assign to the 278 output nodes of the graph, in order. 279 operator_export_type (enum, default OperatorExportTypes.ONNX): 280 281 .. warning:: 282 This option will be deprecated in a future release. Future exported 283 graphs will always use the default opset domain. 284 285 * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops 286 (in the default opset domain). 287 * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops 288 to standard ONNX ops in the default opset domain. If unable to do so 289 (e.g. because support has not been added to convert a particular torch op to ONNX), 290 fall back to exporting the op into a custom opset domain without conversion. Applies 291 to `custom ops <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_ 292 as well as ATen ops. For the exported model to be usable, the runtime must support 293 these non-standard ops. 294 * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten") 295 are exported as ATen ops (in opset domain "org.pytorch.aten"). 296 `ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library, so 297 this instructs the runtime to use PyTorch's implementation of these ops. 298 299 .. warning:: 300 301 Models exported this way are probably runnable only by Caffe2. 302 303 This may be useful if the numeric differences in implementations of operators are 304 causing large differences in behavior between PyTorch and Caffe2 (which is more 305 common on untrained models). 306 307 * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op 308 (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so 309 (e.g. because support has not been added to convert a particular torch op to ONNX), 310 fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for 311 context. 312 For example:: 313 314 graph(%0 : Float): 315 %3 : int = prim::Constant[value=0]() 316 # conversion unsupported 317 %4 : Float = aten::triu(%0, %3) 318 # conversion supported 319 %5 : Float = aten::mul(%4, %0) 320 return (%5) 321 322 Assuming ``aten::triu`` is not supported in ONNX, this will be exported as:: 323 324 graph(%0 : Float): 325 %1 : Long() = onnx::Constant[value={0}]() 326 # not converted 327 %2 : Float = aten::ATen[operator="triu"](%0, %1) 328 # converted 329 %3 : Float = onnx::Mul(%2, %0) 330 return (%3) 331 332 .. warning:: 333 334 Models exported this way are probably runnable only by Caffe2. 335 336 opset_version (int, default 17): The version of the 337 `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_ 338 to target. Must be >= 7 and <= 17. 339 do_constant_folding: Apply the constant-folding optimization. 340 Constant-folding will replace some of the ops that have all constant inputs 341 with pre-computed constant nodes. 342 dynamic_axes: 343 344 By default the exported model will have the shapes of all input and output tensors 345 set to exactly match those given in ``args``. To specify axes of tensors as 346 dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: 347 348 * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or 349 ``output_names``. 350 * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a 351 list, each element is an axis index. 352 353 For example:: 354 355 class SumModule(torch.nn.Module): 356 def forward(self, x): 357 return torch.sum(x, dim=1) 358 359 360 torch.onnx.export( 361 SumModule(), 362 (torch.ones(2, 2),), 363 "onnx.pb", 364 input_names=["x"], 365 output_names=["sum"], 366 ) 367 368 Produces:: 369 370 input { 371 name: "x" 372 ... 373 shape { 374 dim { 375 dim_value: 2 # axis 0 376 } 377 dim { 378 dim_value: 2 # axis 1 379 ... 380 output { 381 name: "sum" 382 ... 383 shape { 384 dim { 385 dim_value: 2 # axis 0 386 ... 387 388 While:: 389 390 torch.onnx.export( 391 SumModule(), 392 (torch.ones(2, 2),), 393 "onnx.pb", 394 input_names=["x"], 395 output_names=["sum"], 396 dynamic_axes={ 397 # dict value: manually named axes 398 "x": {0: "my_custom_axis_name"}, 399 # list value: automatic names 400 "sum": [0], 401 }, 402 ) 403 404 Produces:: 405 406 input { 407 name: "x" 408 ... 409 shape { 410 dim { 411 dim_param: "my_custom_axis_name" # axis 0 412 } 413 dim { 414 dim_value: 2 # axis 1 415 ... 416 output { 417 name: "sum" 418 ... 419 shape { 420 dim { 421 dim_param: "sum_dynamic_axes_1" # axis 0 422 ... 423 424 keep_initializers_as_inputs: If True, all the 425 initializers (typically corresponding to parameters) in the 426 exported graph will also be added as inputs to the graph. If False, 427 then initializers are not added as inputs to the graph, and only 428 the non-parameter inputs are added as inputs. 429 This may allow for better optimizations (e.g. constant folding) by 430 backends/runtimes. 431 432 If True, `deduplicate_initializers` pass will not be executed. This means 433 initializers with duplicated values will not be deduplicated and 434 will be treated as distinct inputs to the graph. This allows different 435 input initializers to be supplied at the runtime following export. 436 437 If ``opset_version < 9``, initializers MUST be part of graph 438 inputs and this argument will be ignored and the behavior will be 439 equivalent to setting this argument to True. 440 441 custom_opsets (dict[str, int], default empty dict): A dict with schema: 442 443 * KEY (str): opset domain name 444 * VALUE (int): opset version 445 446 If a custom opset is referenced by ``model`` but not mentioned in this dictionary, 447 the opset version is set to 1. Only custom opset domain name and version should be 448 indicated through this argument. 449 450 export_modules_as_functions: Flag to enable 451 exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the 452 particular types of modules to export as local functions in ONNX. 453 This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because 454 ``opset_version`` < 15 implies IR version < 8, which means no local function support. 455 Module variables will be exported as function attributes. There are two categories of function 456 attributes. 457 458 1. Annotated attributes: class variables that have type annotations via 459 `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_ 460 will be exported as attributes. 461 Annotated attributes are not used inside the subgraph of ONNX local function because 462 they are not created by PyTorch JIT tracing, but they may be used by consumers 463 to determine whether or not to replace the function with a particular fused kernel. 464 465 2. Inferred attributes: variables that are used by operators inside the module. Attribute names 466 will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from 467 python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. 468 469 * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. 470 * ``True``: export all ``nn.Module`` forward calls as local function nodes. 471 * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, 472 only if the type of the ``nn.Module`` is found in the set. 473 474 autograd_inlining: Flag used to control whether to inline autograd functions. 475 Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. 476 477 Raises: 478 :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. 479 :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it 480 uses an operator that is not supported by the exporter. 481 :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export. 482 All errors are subclasses of :class:`errors.OnnxExporterError`. 483 """ 484 if operator_export_type != _C_onnx.OperatorExportTypes.ONNX: 485 warnings.warn( 486 "Setting `operator_export_type` to something other than default is deprecated. " 487 "The option will be removed in a future release.", 488 category=FutureWarning, 489 ) 490 if training == _C_onnx.TrainingMode.TRAINING: 491 warnings.warn( 492 "Setting `training` to something other than default is deprecated. " 493 "The option will be removed in a future release. Please set the training mode " 494 "before exporting the model.", 495 category=FutureWarning, 496 ) 497 498 args = (args,) if isinstance(args, torch.Tensor) else args 499 if kwargs is not None: 500 args = args + (kwargs,) 501 502 _export( 503 model, 504 args, 505 f, 506 export_params, 507 verbose, 508 training, 509 input_names, 510 output_names, 511 operator_export_type=operator_export_type, 512 opset_version=opset_version, 513 do_constant_folding=do_constant_folding, 514 dynamic_axes=dynamic_axes, 515 keep_initializers_as_inputs=keep_initializers_as_inputs, 516 custom_opsets=custom_opsets, 517 export_modules_as_functions=export_modules_as_functions, 518 autograd_inlining=autograd_inlining, 519 ) 520 521 return None 522 523 524def _is_constant_tensor_list(node): 525 if node.kind() != "prim::Constant": 526 return False 527 output_type = node.output().type() 528 if output_type.isSubtypeOf(_C.ListType.ofTensors()): 529 return True 530 if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): 531 return True 532 533 534# ONNX can't handle constants that are lists of tensors, which can 535# get generated in constant prop. So we split them back into prim::ListConstructs 536 537 538def _split_tensor_list_constants(g, block): 539 for node in block.nodes(): 540 for subblock in node.blocks(): 541 _split_tensor_list_constants(g, subblock) 542 if _is_constant_tensor_list(node): 543 inputs = [] 544 for val in node.output().toIValue(): 545 input = g.insertConstant(val) 546 input.node().moveBefore(node) 547 input.node().copyMetadata(node) 548 inputs.append(input) 549 550 lc = ( 551 g.create("prim::ListConstruct", inputs) 552 .insertBefore(node) 553 .output() 554 .setType(_C.ListType.ofTensors()) 555 ) 556 lc.node().copyMetadata(node) 557 node.output().replaceAllUsesWith(lc) 558 559 560def _optimize_graph( 561 graph: _C.Graph, 562 operator_export_type: _C_onnx.OperatorExportTypes, 563 _disable_torch_constant_prop: bool = False, 564 fixed_batch_size: bool = False, 565 params_dict=None, 566 dynamic_axes=None, 567 input_names=None, 568 module=None, 569): 570 if params_dict is None: 571 params_dict = {} 572 573 # Inline everything 574 _C._jit_pass_inline(graph) 575 576 # Remove fork/wait nodes 577 _C._jit_pass_inline_fork_wait(graph) 578 _C._jit_pass_lint(graph) 579 if GLOBALS.autograd_inlining: 580 _C._jit_pass_onnx_autograd_function_process(graph) 581 _C._jit_pass_lower_all_tuples(graph) 582 583 # we now record some ops like ones/zeros 584 # into a trace where we previously recorded constants. 585 # use constant prop to maintain our current level of onnx support 586 # without implementing symbolics for all of them 587 if _disable_torch_constant_prop is False: 588 _C._jit_pass_constant_propagation(graph) 589 590 _split_tensor_list_constants(graph, graph) 591 # run dce to eliminate dead parts of the graph that might have been 592 # left behind by things like symbolic_override 593 _C._jit_pass_dce(graph) 594 _C._jit_pass_lint(graph) 595 596 # CSE should improve perf when Autocast is used with disabled cache 597 # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 598 # Must run before _C._jit_pass_erase_number_types to prevent type substitution 599 if _C._jit_pass_cse(graph): 600 _C._jit_pass_onnx_lint(graph) 601 602 _C._jit_pass_canonicalize_graph_fuser_ops(graph) 603 _C._jit_pass_lint(graph) 604 _C._jit_pass_peephole(graph, True) 605 _C._jit_pass_fuse_addmm(graph) 606 _C._jit_pass_lint(graph) 607 608 _C._jit_pass_peephole(graph, True) 609 _C._jit_pass_lower_all_tuples(graph) 610 # in _jit_pass_onnx, symbolic functions are called for each node for conversion. 611 # However, there are nodes that cannot be converted without additional context. 612 # For example, the number of outputs from split (and whether it is static or dynamic) is unknown 613 # until the point where it is unpacked by listUnpack node. 614 # This pass does a preprocess, and prepares the nodes such that enough context can be received 615 # by the symbolic function. 616 _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) 617 _C._jit_pass_onnx_preprocess(graph) 618 619 # onnx does not support tuples, so try to remove them 620 _C._jit_pass_lint(graph) 621 622 # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 623 _C._jit_pass_prepare_division_for_onnx(graph) 624 625 _C._jit_pass_onnx_remove_print(graph) 626 _C._jit_pass_onnx_preprocess_caffe2(graph) 627 628 symbolic_helper._quantized_ops.clear() 629 # Unpack quantized weights for conv and linear ops and insert into graph. 630 _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict) 631 # onnx only supports tensors, so we turn all out number types into tensors 632 _C._jit_pass_erase_number_types(graph) 633 if GLOBALS.onnx_shape_inference: 634 input_names = [] if input_names is None else input_names 635 dynamic_axes = {} if dynamic_axes is None else dynamic_axes 636 _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) 637 _C._jit_pass_onnx_lint(graph) 638 639 graph = _C._jit_pass_onnx(graph, operator_export_type) 640 _C._jit_pass_onnx_lint(graph) 641 _C._jit_pass_lint(graph) 642 643 _C._jit_pass_onnx_scalar_type_analysis( 644 graph, True, GLOBALS.export_onnx_opset_version 645 ) 646 _C._jit_pass_lint(graph) 647 648 _C._jit_pass_onnx_peephole( 649 graph, GLOBALS.export_onnx_opset_version, fixed_batch_size 650 ) 651 _C._jit_pass_lint(graph) 652 653 # graph is not a valid jit graph anymore because types have been replaced 654 # (e.g. int with Tensor), so it now contains operators that don't actually 655 # exist. We can't run normal dead code elimination because it'd fail trying 656 # to look up if an operator has side effects, but we can run a dead code 657 # elimination variant that doesn't need to look up if an op has side effects. 658 _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 659 _C._jit_pass_lint(graph) 660 graph = _C._jit_pass_canonicalize(graph) 661 _C._jit_pass_lint(graph) 662 if GLOBALS.onnx_shape_inference: 663 _C._jit_pass_onnx_graph_shape_type_inference( 664 graph, params_dict, GLOBALS.export_onnx_opset_version 665 ) 666 667 return graph 668 669 670def warn_on_static_input_change(input_states): 671 """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. 672 673 We accept dictionaries and strings as ONNX inputs, but they should be only for 674 configuration use. we detect here if these inputs are modified, and if so we warn 675 the user that the changes won't take effect in the traced ONNX graph. 676 """ 677 for input, traced_input in zip(input_states[0], input_states[1]): 678 if isinstance(input, dict): 679 if list(input.keys()) != list(traced_input.keys()): 680 warning = ( 681 "We detected that you are modifying a dictionary that is an input to your " 682 "model. " 683 "Note that dictionaries are allowed as inputs in ONNX but they should be " 684 "handled with care. " 685 "Usages of dictionaries is not recommended, and should not be used except " 686 "for configuration use. " 687 "Also note that the order and values of the keys must remain the same. " 688 ) 689 warnings.warn(warning) 690 elif isinstance(input, str): 691 if input != traced_input: 692 warning = ( 693 "The model seems to have string inputs/outputs. " 694 "Note that strings will not appear as inputs/outputs of the ONNX graph. " 695 ) 696 warnings.warn(warning) 697 698 699def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): 700 """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" 701 return arg_value 702 703 704def _decide_keep_init_as_input( 705 keep_initializers_as_inputs: bool | None, 706 operator_export_type: _C_onnx.OperatorExportTypes, 707 opset_version: int, 708): 709 """Decides whether the initializers in the graph should be listed as ONNX graph inputs. 710 711 This method encapsulates the logic to decide whether the initializers in the graph 712 should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4). 713 If keep_initializers_as_inputs is not specified (None), then we decide whether to keep 714 initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type 715 is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other 716 export types keep initializers as input (val_keep_init_as_ip=True). 717 If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8, 718 in which case it must be ignored because for opset version <= 8, all initializers MUST be 719 part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True. 720 721 Special handling is needed for opset version 8 or lower, because irrespective 722 of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3 723 semantics, i.e. all initializers must be listed as ONNX graph input. 724 """ 725 726 if opset_version < 9: 727 if keep_initializers_as_inputs is False: 728 warnings.warn( 729 "Setting 'keep_initializers_as_inputs=False' for opset version" 730 "8 or lower would lead to an invalid ONNX graph. Therefore, " 731 "'keep_initializers_as_inputs=False' is ignored during export." 732 "Exported model will have initializers as graph inputs (compliant " 733 " to ONNX IR v3)." 734 ) 735 return True # i.e. True == initializers are part of graph input (ONNX IR v3) 736 val_keep_init_as_ip = ( 737 True if keep_initializers_as_inputs is None else keep_initializers_as_inputs 738 ) 739 if ( 740 keep_initializers_as_inputs is None 741 and operator_export_type is _C_onnx.OperatorExportTypes.ONNX 742 ): 743 val_keep_init_as_ip = False 744 return val_keep_init_as_ip 745 746 747def _decide_add_node_names(add_node_names, operator_export_type): 748 return _resolve_args_by_export_type( 749 "add_node_names", add_node_names, operator_export_type 750 ) 751 752 753def _decide_constant_folding(do_constant_folding, operator_export_type, training): 754 do_constant_folding = _resolve_args_by_export_type( 755 "do_constant_folding", do_constant_folding, operator_export_type 756 ) 757 if do_constant_folding and ( 758 training is not None and training is not _C_onnx.TrainingMode.EVAL 759 ): 760 warnings.warn( 761 "It is recommended that constant folding be turned off ('do_constant_folding=False') " 762 "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' " 763 "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some " 764 "learnable model parameters may not translate correctly in the exported ONNX model " 765 "because constant folding mutates model parameters. Please consider " 766 "turning off constant folding or setting the training=TrainingMode.EVAL." 767 ) 768 return do_constant_folding 769 770 771def _signature(model) -> inspect.Signature: 772 should_be_callable = getattr(model, "forward", model) 773 if callable(should_be_callable): 774 return inspect.signature(should_be_callable) 775 raise ValueError("model has no forward method and is not callable") 776 777 778def _decide_input_format(model, args): 779 try: 780 sig = _signature(model) 781 except ValueError as e: 782 warnings.warn(f"{e}, skipping _decide_input_format") 783 return args 784 try: 785 ordered_list_keys = list(sig.parameters.keys()) 786 if ordered_list_keys[0] == "self": 787 ordered_list_keys = ordered_list_keys[1:] 788 args_dict: dict = {} 789 if isinstance(args, list): 790 args_list = args 791 elif isinstance(args, tuple): 792 args_list = list(args) 793 else: 794 args_list = [args] 795 if isinstance(args_list[-1], dict): 796 args_dict = args_list[-1] 797 args_list = args_list[:-1] 798 n_nonkeyword = len(args_list) 799 for optional_arg in ordered_list_keys[n_nonkeyword:]: 800 if optional_arg in args_dict: 801 args_list.append(args_dict[optional_arg]) 802 # Check if this arg has a default value 803 else: 804 param = sig.parameters[optional_arg] 805 if param.default != param.empty: 806 args_list.append(param.default) 807 args = args_list if isinstance(args, list) else tuple(args_list) 808 # Cases of models with no input args 809 except IndexError: 810 warnings.warn("No input args, skipping _decide_input_format") 811 except Exception as e: 812 warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") 813 return args 814 815 816def _from_dynamic_axes_to_dynamic_shapes( 817 model, 818 dynamic_axes: Mapping[str, Mapping[int, str]] 819 | Mapping[str, Sequence[int]] 820 | None = None, 821 input_names: Sequence[str] | None = None, 822) -> dict[str, Any] | None: 823 """ 824 825 dynamic_axes examples: 826 (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} 827 (2) dynamic_axes = {"x": [0], "y": [1]} 828 829 these will be converted to dynamic_shapes respectively: 830 (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} 831 (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names 832 833 """ 834 if dynamic_axes is None: 835 return None 836 837 if input_names is None: 838 input_names_set = set() 839 else: 840 input_names_set = set(input_names) 841 842 dynamic_shapes: dict[str, Any | None] = {} 843 for input_name, axes in dynamic_axes.items(): 844 if input_name in input_names_set: 845 raise ValueError( 846 "Assinging new input names is not supported yet. Please use model forward signature " 847 "to specify input names in dynamix_axes." 848 ) 849 if isinstance(axes, dict): 850 dynamic_shapes[input_name] = { 851 k: torch.export.Dim(v) for k, v in axes.items() 852 } 853 elif isinstance(axes, list): 854 dynamic_shapes[input_name] = { 855 k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes 856 } 857 else: 858 raise TypeError( 859 f"dynamic_axes value must be either a dict or a list, but got {type(axes)}" 860 ) 861 # torch.export.export needs static dim to present in dynamic_shapes 862 # for all input tensors, so we need to add them with None 863 try: 864 sig = _signature(model) 865 except ValueError as e: 866 warnings.warn(f"{e}, skipping auto filling None on static axes...") 867 return dynamic_shapes 868 for input_name in sig.parameters.keys(): 869 if input_name not in dynamic_shapes: 870 dynamic_shapes[input_name] = None 871 return dynamic_shapes 872 873 874def _trace(func, args, operator_export_type, return_outs=False): 875 # Special case for common case of passing a single Tensor 876 if isinstance(args, torch.Tensor): 877 args = (args,) 878 879 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( 880 func, 881 args, 882 strict=False, 883 _force_outplace=False, 884 _return_inputs_states=True, 885 ) 886 warn_on_static_input_change(inputs_states) 887 888 trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) 889 if return_outs: 890 return trace_graph, torch_out 891 return trace_graph 892 893 894def _trace_and_get_graph_from_model(model, args): 895 # A basic sanity check: make sure the state_dict keys are the same 896 # before and after running the model. Fail fast! 897 orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() 898 899 # Disable Autocast cache because it replaces kernel's weight and bias 900 # by (undesired) constants. 901 # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 902 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() 903 torch.set_autocast_cache_enabled(False) 904 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( 905 model, 906 args, 907 strict=False, 908 _force_outplace=False, 909 _return_inputs_states=True, 910 ) 911 torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) 912 913 warn_on_static_input_change(inputs_states) 914 915 if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): 916 raise RuntimeError( 917 "state_dict changed after running the tracer; " 918 "something weird is happening in your model!" 919 ) 920 921 return trace_graph, torch_out 922 923 924def _get_param_count_list(method_graph, args_params): 925 param_count_list = [] 926 for input_, arg_params_ in zip(method_graph.inputs(), args_params): 927 if "PackedParams" in str(input_.type()): 928 in_vars, _ = torch.jit._flatten(arg_params_) 929 param_count_list.append(len(in_vars)) 930 else: 931 param_count_list.append(arg_params_ is not None) 932 933 return param_count_list 934 935 936def _check_flatten_did_not_remove(original, jit_flattened): 937 """torch.jit._flatten removes None. Check if it did so in this case.""" 938 939 def flatten(x): 940 if isinstance(x, (list, tuple)): 941 for inner in x: 942 yield from flatten(inner) 943 elif isinstance(x, dict): 944 for inner in x.values(): 945 yield from flatten(inner) 946 else: 947 yield x 948 949 flattened_with_none = list(flatten(original)) 950 num_none = len(flattened_with_none) - len(jit_flattened) 951 assert num_none >= 0 952 if num_none: 953 raise ValueError( 954 f"args contained {num_none} None's after flattening. " 955 "When exporting a ScriptModule or ScriptFunction, no args may " 956 "be None because that breaks type propagation." 957 ) 958 959 960def _create_jit_graph( 961 model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any] 962) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]: 963 if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): 964 flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) 965 _check_flatten_did_not_remove(args, flattened_args) 966 torch_out = None 967 968 if isinstance(model, torch.jit.ScriptModule): 969 try: 970 graph = model.forward.graph # type: ignore[attr-defined] 971 except AttributeError as e: 972 raise RuntimeError("'forward' method must be a script method") from e 973 _C._jit_pass_onnx_function_substitution(graph) 974 freezed_module = _C._freeze_module( 975 cast(_C.ScriptModule, model._c), preserveParameters=True 976 ) 977 module, params = _C._jit_onnx_list_model_parameters(freezed_module) 978 method_graph = module._get_method("forward").graph 979 args_params = tuple(args) + tuple(params) 980 param_count_list = _get_param_count_list(method_graph, args_params) 981 in_vars, _ = torch.jit._flatten(args_params) 982 graph = _C._propagate_and_assign_input_shapes( 983 method_graph, tuple(in_vars), param_count_list, False, False 984 ) 985 return graph, params, torch_out, module 986 987 # torch.jit.ScriptFunction 988 params = [] 989 graph = model.graph 990 _C._jit_pass_onnx_function_substitution(graph) 991 param_count_list = _get_param_count_list(graph, args) 992 graph = _C._propagate_and_assign_input_shapes( 993 graph, flattened_args, param_count_list, False, False 994 ) 995 return graph, params, torch_out, None 996 997 graph, torch_out = _trace_and_get_graph_from_model(model, args) 998 _C._jit_pass_onnx_lint(graph) 999 state_dict = torch.jit._unique_state_dict(model) 1000 params = list(state_dict.values()) 1001 graph_inputs = list(graph.inputs()) 1002 user_input_num = len(graph_inputs) - len(state_dict) 1003 param_names = list(state_dict.keys()) 1004 for i, inp in enumerate(graph_inputs): 1005 if i >= user_input_num: 1006 inp.setDebugName(param_names[i - user_input_num]) 1007 _C._jit_pass_onnx_function_substitution(graph) 1008 return graph, params, torch_out, None 1009 1010 1011def _get_named_param_dict(graph, params): 1012 input_and_param_names = [val.debugName() for val in graph.inputs()] 1013 param_names = input_and_param_names[len(input_and_param_names) - len(params) :] 1014 _params_dict = dict(zip(param_names, params)) 1015 return _params_dict 1016 1017 1018def _get_example_outputs(model, args): 1019 input_args = copy.deepcopy(args) 1020 input_kwargs = {} 1021 if input_args and isinstance(input_args[-1], dict): 1022 input_kwargs = input_args[-1] 1023 input_args = input_args[:-1] 1024 1025 example_outputs = model(*input_args, **input_kwargs) 1026 if isinstance(example_outputs, list): 1027 example_outputs = [example_outputs] 1028 elif not isinstance(example_outputs, tuple): 1029 example_outputs = (example_outputs,) 1030 1031 return example_outputs 1032 1033 1034_qtype_vtype_map = { 1035 torch.quint8: torch.uint8, 1036 torch.qint8: torch.int8, 1037 torch.qint32: torch.int32, 1038 torch.quint4x2: torch.int8, 1039} 1040 1041 1042def unpack_quantized_tensor(value, cast_onnx_accepted=True): 1043 if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map: 1044 q_value_dequantize = value.dequantize() 1045 q_scale = ( 1046 torch.tensor(value.q_scale(), dtype=torch.double) 1047 if cast_onnx_accepted 1048 else torch.tensor(value.q_scale(), dtype=torch.float32) 1049 ) 1050 q_zero_point = ( 1051 torch.tensor(value.q_zero_point(), dtype=torch.int64) 1052 if cast_onnx_accepted 1053 else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype]) 1054 ) 1055 q_value = q_value_dequantize / q_scale + q_zero_point 1056 q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype]) 1057 return q_value, q_scale, q_zero_point 1058 else: 1059 return (value,) 1060 1061 1062def _pre_trace_quant_model(model, args): 1063 r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return 1064 original model. 1065 1066 This is due to https://github.com/pytorch/pytorch/issues/75761. 1067 """ 1068 if any( 1069 hasattr(m, "_packed_params") for m in getattr(model, "modules", list)() 1070 ) or any(getattr(arg, "is_quantized", False) for arg in args): 1071 return torch.jit.trace(model, args) 1072 return model 1073 1074 1075def _model_to_graph( 1076 model, 1077 args, 1078 verbose=False, 1079 input_names=None, 1080 output_names=None, 1081 operator_export_type=_C_onnx.OperatorExportTypes.ONNX, 1082 do_constant_folding=True, 1083 _disable_torch_constant_prop=False, 1084 fixed_batch_size=False, 1085 training=_C_onnx.TrainingMode.EVAL, 1086 dynamic_axes=None, 1087) -> tuple[ 1088 _C.Graph, 1089 dict[str, torch.Tensor], 1090 torch.Tensor 1091 | tuple[torch.Tensor, ...] 1092 | list[torch.Tensor] 1093 | dict[str, torch.Tensor] 1094 | Any 1095 | None, 1096]: 1097 """Converts model into an ONNX graph. 1098 1099 Returns: 1100 graph: A TorchScript IR Graph with ONNX nodes. 1101 params_dict: Dict from input param name to param value. 1102 torch_out: The output tensors resulting from the trace of ``model``. 1103 If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`, 1104 this will be None, since we are not doing any tracing. 1105 """ 1106 # TODO: can we simplify this to always return a tuple of Tensor or None? 1107 1108 # Special case for common case of passing a single Tensor 1109 if isinstance(args, (torch.Tensor, int, float, bool)): 1110 args = (args,) 1111 1112 model = _pre_trace_quant_model(model, args) 1113 graph, params, torch_out, module = _create_jit_graph(model, args) 1114 params_dict = _get_named_param_dict(graph, params) 1115 1116 try: 1117 graph = _optimize_graph( 1118 graph, 1119 operator_export_type, 1120 _disable_torch_constant_prop=_disable_torch_constant_prop, 1121 fixed_batch_size=fixed_batch_size, 1122 params_dict=params_dict, 1123 dynamic_axes=dynamic_axes, 1124 input_names=input_names, 1125 module=module, 1126 ) 1127 except Exception as e: 1128 torch.onnx.log("Torch IR graph at exception: ", graph) 1129 raise 1130 1131 is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) 1132 if is_script: 1133 example_outputs = _get_example_outputs(model, args) 1134 example_outputs_final = () 1135 for example_output in example_outputs: 1136 example_outputs_final += unpack_quantized_tensor(example_output) 1137 out_vars, desc = torch.jit._flatten(example_outputs_final) 1138 _C._jit_pass_onnx_assign_output_shape( 1139 graph, 1140 out_vars, 1141 desc, 1142 GLOBALS.onnx_shape_inference, 1143 is_script, 1144 GLOBALS.export_onnx_opset_version, 1145 ) 1146 1147 # NB: ONNX requires complete information about output types, which might be 1148 # erased by some optimizations, so we need to set it explicitly again. 1149 else: 1150 if not isinstance(torch_out, (list, tuple)): 1151 output_wrapped = [torch_out] 1152 else: 1153 output_wrapped = torch_out # type: ignore[assignment] 1154 1155 output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped)) 1156 # assign_output_shape pass is not compatible with quantized outputs. 1157 # Quantized outputs are flattened to 3 values in ONNX, while packed as 1158 # single value in PyTorch. 1159 if not any(getattr(out, "is_quantized", False) for out in output_tensors): 1160 _C._jit_pass_onnx_assign_output_shape( 1161 graph, 1162 output_tensors, 1163 out_desc, 1164 GLOBALS.onnx_shape_inference, 1165 is_script, 1166 GLOBALS.export_onnx_opset_version, 1167 ) 1168 1169 _set_input_and_output_names(graph, input_names, output_names) 1170 params_dict = _get_named_param_dict(graph, params) 1171 1172 if ( 1173 do_constant_folding 1174 and GLOBALS.export_onnx_opset_version 1175 >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET 1176 ): 1177 if training is None or training == _C_onnx.TrainingMode.EVAL: 1178 params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict) 1179 1180 params_dict = _C._jit_pass_onnx_constant_fold( 1181 graph, params_dict, GLOBALS.export_onnx_opset_version 1182 ) 1183 _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 1184 1185 if GLOBALS.onnx_shape_inference: 1186 _C._jit_pass_onnx_graph_shape_type_inference( 1187 graph, params_dict, GLOBALS.export_onnx_opset_version 1188 ) 1189 1190 params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) 1191 1192 # For ONNX opset < 9, constants only have three data types: float16, float, double. 1193 # In this pass transform constants of other data types to float/double + cast operator. 1194 if GLOBALS.export_onnx_opset_version < 9: 1195 _C._jit_pass_onnx_cast_all_constant_to_floating(graph) 1196 1197 params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) 1198 _C._jit_decay_packed_param_input_types(graph) 1199 1200 # If output names lack a proper name and are identified only by their unique 1201 # give them a legible name for debugging purposes 1202 _apply_friendly_debug_names(graph, params_dict) 1203 1204 return graph, params_dict, torch_out 1205 1206 1207@torch._disable_dynamo 1208@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead") 1209def export_to_pretty_string( 1210 model, 1211 args, 1212 export_params=True, 1213 verbose=False, 1214 training=_C_onnx.TrainingMode.EVAL, 1215 input_names=None, 1216 output_names=None, 1217 operator_export_type=_C_onnx.OperatorExportTypes.ONNX, 1218 export_type=None, 1219 google_printer=False, 1220 opset_version=None, 1221 keep_initializers_as_inputs=None, 1222 custom_opsets=None, 1223 add_node_names=True, 1224 do_constant_folding=True, 1225 dynamic_axes=None, 1226): 1227 """Similar to :func:`export`, but returns a text representation of the ONNX model. 1228 1229 Only differences in args listed below. All other args are the same 1230 as :func:`export`. 1231 1232 Args: 1233 add_node_names (bool, default True): Whether or not to set 1234 NodeProto.name. This makes no difference unless 1235 ``google_printer=True``. 1236 google_printer (bool, default False): If False, will return a custom, 1237 compact representation of the model. If True will return the 1238 protobuf's `Message::DebugString()`, which is more verbose. 1239 1240 Returns: 1241 A UTF-8 str containing a human-readable representation of the ONNX model. 1242 """ 1243 if opset_version is None: 1244 opset_version = _constants.ONNX_DEFAULT_OPSET 1245 if custom_opsets is None: 1246 custom_opsets = {} 1247 GLOBALS.export_onnx_opset_version = opset_version 1248 GLOBALS.operator_export_type = operator_export_type 1249 1250 with exporter_context(model, training, verbose): 1251 val_keep_init_as_ip = _decide_keep_init_as_input( 1252 keep_initializers_as_inputs, operator_export_type, opset_version 1253 ) 1254 val_add_node_names = _decide_add_node_names( 1255 add_node_names, operator_export_type 1256 ) 1257 val_do_constant_folding = _decide_constant_folding( 1258 do_constant_folding, operator_export_type, training 1259 ) 1260 args = _decide_input_format(model, args) 1261 graph, params_dict, torch_out = _model_to_graph( 1262 model, 1263 args, 1264 verbose, 1265 input_names, 1266 output_names, 1267 operator_export_type, 1268 val_do_constant_folding, 1269 training=training, 1270 dynamic_axes=dynamic_axes, 1271 ) 1272 1273 return graph._pretty_print_onnx( # type: ignore[attr-defined] 1274 params_dict, 1275 opset_version, 1276 False, 1277 operator_export_type, 1278 google_printer, 1279 val_keep_init_as_ip, 1280 custom_opsets, 1281 val_add_node_names, 1282 ) 1283 1284 1285@_deprecation.deprecated("2.5", "the future", "avoid using this function") 1286def unconvertible_ops( 1287 model, 1288 args, 1289 training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, 1290 opset_version: int | None = None, 1291) -> tuple[_C.Graph, list[str]]: 1292 """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`. 1293 1294 The list is approximated because some ops may be removed during the conversion 1295 process and don't need to be converted. Some other ops may have partial support 1296 that will fail conversion with particular inputs. Please open a Github Issue 1297 for op support requests. 1298 1299 Args: 1300 model: Same as the `model` parameter in :func:`torch.onnx.export`. 1301 args: Same as the `args` parameter in :func:`torch.onnx.export`. 1302 training: Same as the `training` parameter in :func:`torch.onnx.export`. 1303 opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`. 1304 1305 Returns: 1306 The JIT graph and a list of unconvertible ops in the format of "domain::op". 1307 """ 1308 1309 opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET 1310 GLOBALS.export_onnx_opset_version = opset_version 1311 1312 try: 1313 with exporter_context(model, training, verbose=False): 1314 # Create a mostly clean JIT graph that contains the plain aten and 1315 # other ops we can check with the symbolic registry. 1316 # NOTE: We don't want to actually convert any ops to ONNX or run any 1317 # symbolic functions because there is a higher chance that a pass 1318 # fails or an unconvertible op messes up the graph during ONNX conversion. 1319 # This way we can always generate a list just by looking at the names 1320 # of the ops in the graph. 1321 args = _decide_input_format(model, args) 1322 model = _pre_trace_quant_model(model, args) 1323 graph, _, _, module = _create_jit_graph(model, args) 1324 _C._jit_pass_inline(graph) 1325 _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) 1326 _C._jit_pass_erase_number_types(graph) 1327 _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 1328 except Exception as e: 1329 raise errors.OnnxExporterError( 1330 "Failed to discover unconvertible ops because of errors during the JIT graph " 1331 "generation process." 1332 ) from e 1333 1334 unsupported_ops = [] 1335 for node in graph.nodes(): 1336 domain_op = node.kind() 1337 if domain_op.startswith(("onnx::", "prim::")): 1338 # We consider onnx and prim ops as supported ops, even though some "prim" 1339 # ops are not implemented as symbolic functions, because they may be 1340 # eliminated in the conversion passes. Users may still see errors caused 1341 # by prim ops even though they don't show up in the list. 1342 continue 1343 if not registration.registry.is_registered_op( 1344 domain_op.rstrip("_"), opset_version 1345 ): 1346 # We consider all registered ops supported, even though some of them are 1347 # only partially supported, because there is not yet a good way to check 1348 # if an op is fully supported. 1349 # TODO(justinchuby): Create a way to check if an op is fully supported. 1350 unsupported_ops.append(domain_op) 1351 return graph, unsupported_ops 1352 1353 1354def _setup_trace_module_map( 1355 model: torch.nn.Module | torch.jit.ScriptModule, 1356 export_modules_as_functions: bool | Collection[type[torch.nn.Module]], 1357) -> set[str]: 1358 def __register_attribute_hook(): 1359 attr_name = "_onnx_attrs" 1360 1361 def _track_module_attributes_forward_pre_hook(module, input): 1362 setattr(module, attr_name, _get_module_attributes(module)) 1363 1364 def _track_module_attributes_forward_hook(module, input, output): 1365 tracing_state = _C._get_tracing_state() 1366 if not tracing_state: 1367 return 1368 1369 graph = tracing_state.graph() 1370 onnx_attrs = {} 1371 if hasattr(module, attr_name): 1372 onnx_attrs = getattr(module, attr_name) 1373 delattr(module, attr_name) 1374 1375 _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) 1376 1377 for m in model.modules(): 1378 m.register_forward_hook(_track_module_attributes_forward_hook) 1379 m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) 1380 1381 def _unqualified_variable_name(qualified_name: str) -> str: 1382 """ 1383 Parse qualified variable name and return the unqualified version. 1384 1385 Pure numeric atoms are considered inadequate, so this function will look past them, 1386 and start from the first non-numeric atom. 1387 1388 Example: 1389 >>> _unqualified_variable_name("__main__.Foo.bar") 1390 'bar' 1391 >>> _unqualified_variable_name("__main__.Foo.bar.0") 1392 'bar.0' 1393 """ 1394 name_atoms = qualified_name.split(".") 1395 for i, atom in reversed(list(enumerate(name_atoms))): 1396 if not atom.isnumeric(): 1397 return ".".join(name_atoms[i:]) 1398 return qualified_name 1399 1400 trace_module_map = { 1401 _m: torch._C._jit_onnx_create_full_scope_name( 1402 torch.typename(type(_m)), _unqualified_variable_name(_n) 1403 ) 1404 for _n, _m in model.named_modules() 1405 } 1406 torch.jit._trace._trace_module_map = trace_module_map 1407 if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: 1408 module_typenames = {torch.typename(type(module)) for module in trace_module_map} 1409 elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: 1410 1411 def _find_typename(v): 1412 if isinstance(v, type): 1413 return torch.typename(v) 1414 else: 1415 raise RuntimeError( 1416 "Only type of the `nn.Module` should be " 1417 "passed in the set for argument `export_modules_as_functions`. " 1418 f"Got `{type(v).__name__}`." 1419 ) 1420 1421 module_typenames = {_find_typename(v) for v in export_modules_as_functions} 1422 else: 1423 module_typenames = set() 1424 1425 if module_typenames: 1426 __register_attribute_hook() 1427 1428 return module_typenames 1429 1430 1431def _reset_trace_module_map(): 1432 torch.jit._trace._trace_module_map = None 1433 _C._jit_pass_onnx_clear_scope_records() 1434 1435 1436def _get_module_attributes(module): 1437 annotations = typing.get_type_hints(type(module)) 1438 base_m_annotations = typing.get_type_hints(torch.nn.Module) 1439 [annotations.pop(k, None) for k in base_m_annotations] 1440 # Check whether module attributes can be accessed. Some classes 1441 # define attributes but don't provide access to them in their 1442 # constructor. 1443 # 1444 # For example, torch.nn.Embedding has the `freeze` variable and its 1445 # type specified in the class but the attribute is not created in the 1446 # constructor. In other words, there is no `self.freeze = <True | False>` 1447 # in the constructor. 1448 # 1449 # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120 1450 attrs = {} 1451 for k in annotations: 1452 try: 1453 attrs[k] = getattr(module, k) 1454 except AttributeError: 1455 torch.onnx.log(f"Skipping module attribute '{k}'") 1456 continue 1457 return attrs 1458 1459 1460def _export( 1461 model, 1462 args, 1463 f, 1464 export_params=True, 1465 verbose=False, 1466 training=_C_onnx.TrainingMode.EVAL, 1467 input_names=None, 1468 output_names=None, 1469 operator_export_type=_C_onnx.OperatorExportTypes.ONNX, 1470 export_type=None, 1471 opset_version=None, 1472 do_constant_folding=True, 1473 dynamic_axes=None, 1474 keep_initializers_as_inputs=None, 1475 fixed_batch_size=False, 1476 custom_opsets=None, 1477 add_node_names=True, 1478 onnx_shape_inference=True, 1479 export_modules_as_functions: Any = False, 1480 autograd_inlining=True, 1481): 1482 assert GLOBALS.in_onnx_export is False 1483 1484 if export_type is None: 1485 export_type = _exporter_states.ExportTypes.PROTOBUF_FILE 1486 1487 if isinstance(model, torch.nn.DataParallel): 1488 raise ValueError( 1489 "torch.nn.DataParallel is not supported by ONNX " 1490 "exporter, please use 'attribute' module to " 1491 "unwrap model from torch.nn.DataParallel. Try " 1492 "torch.onnx.export(model.module, ...)" 1493 ) 1494 1495 GLOBALS.onnx_shape_inference = onnx_shape_inference 1496 1497 if opset_version is None: 1498 opset_version = _constants.ONNX_DEFAULT_OPSET 1499 1500 # torch.onnx.export does not support opset versions >=18 1501 if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET: 1502 # We do not want to fail because we should still allow users to create 1503 # custom symbolic functions for opset>17 1504 warnings.warn( 1505 f"Exporting to ONNX opset version {opset_version} is not supported. " 1506 f"by 'torch.onnx.export()'. " 1507 f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. " 1508 f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ", 1509 category=errors.OnnxExporterWarning, 1510 ) 1511 1512 if export_modules_as_functions and opset_version < 15: 1513 raise ValueError( 1514 "`export_modules_as_functions` is not supported for `opset_version` < 15." 1515 "This is because `opset_version` < 15 implies IR version < 8, which means " 1516 "no local function support. " 1517 ) 1518 if not operator_export_type: 1519 operator_export_type = _C_onnx.OperatorExportTypes.ONNX 1520 1521 # By default, training=TrainingMode.EVAL, 1522 # which is good because running a model in training mode could result in 1523 # internal buffers getting updated, dropout getting applied, etc. 1524 # If you really know what you're doing, you can turn 1525 # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE, 1526 # (to preserve whatever the original training mode was.) 1527 GLOBALS.export_onnx_opset_version = opset_version 1528 GLOBALS.operator_export_type = operator_export_type 1529 1530 try: 1531 GLOBALS.in_onnx_export = True 1532 _autograd_inlining_previous = GLOBALS.autograd_inlining 1533 GLOBALS.autograd_inlining = autograd_inlining 1534 1535 module_typenames_to_export_as_functions: set[str] = set() 1536 if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)): 1537 module_typenames_to_export_as_functions = _setup_trace_module_map( 1538 model, export_modules_as_functions 1539 ) 1540 1541 with exporter_context(model, training, verbose): 1542 val_keep_init_as_ip = _decide_keep_init_as_input( 1543 keep_initializers_as_inputs, 1544 operator_export_type, 1545 opset_version, 1546 ) 1547 val_add_node_names = _decide_add_node_names( 1548 add_node_names, operator_export_type 1549 ) 1550 val_do_constant_folding = _decide_constant_folding( 1551 do_constant_folding, operator_export_type, training 1552 ) 1553 # Normally f can be a file-like object, but for large models, the external data format requires a 1554 # valid `model_file_location`. Code in export.cpp will enforce this. 1555 if isinstance(f, str): 1556 model_file_location = f 1557 else: 1558 model_file_location = "" 1559 args = _decide_input_format(model, args) 1560 if dynamic_axes is None: 1561 dynamic_axes = {} 1562 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) 1563 1564 graph, params_dict, torch_out = _model_to_graph( 1565 model, 1566 args, 1567 verbose, 1568 input_names, 1569 output_names, 1570 operator_export_type, 1571 val_do_constant_folding, 1572 fixed_batch_size=fixed_batch_size, 1573 training=training, 1574 dynamic_axes=dynamic_axes, 1575 ) 1576 1577 # TODO: Don't allocate a in-memory string for the protobuf 1578 defer_weight_export = ( 1579 export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE 1580 ) 1581 if custom_opsets is None: 1582 custom_opsets = {} 1583 1584 _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 1585 node_attr_to_name = {} # type: ignore[var-annotated] 1586 if module_typenames_to_export_as_functions: 1587 # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes. 1588 node_attr_to_name = _C._jit_pass_onnx_function_extraction( 1589 graph, 1590 module_typenames_to_export_as_functions, 1591 list(params_dict.keys()), 1592 ) 1593 1594 if keep_initializers_as_inputs is not True: 1595 params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment] 1596 graph, 1597 params_dict, # type: ignore[arg-type] 1598 getattr(model, "training", False), # type: ignore[arg-type] 1599 ) 1600 _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) 1601 if export_params: 1602 ( 1603 proto, 1604 export_map, 1605 val_use_external_data_format, 1606 node_names, 1607 ) = graph._export_onnx( # type: ignore[attr-defined] 1608 params_dict, 1609 opset_version, 1610 dynamic_axes, 1611 defer_weight_export, 1612 operator_export_type, 1613 not verbose, 1614 val_keep_init_as_ip, 1615 custom_opsets, 1616 val_add_node_names, 1617 model_file_location, 1618 node_attr_to_name, 1619 ) 1620 else: 1621 ( 1622 proto, 1623 export_map, 1624 val_use_external_data_format, 1625 node_names, 1626 ) = graph._export_onnx( # type: ignore[attr-defined] 1627 {}, 1628 opset_version, 1629 dynamic_axes, 1630 False, 1631 operator_export_type, 1632 not verbose, 1633 val_keep_init_as_ip, 1634 custom_opsets, 1635 val_add_node_names, 1636 model_file_location, 1637 node_attr_to_name, 1638 ) 1639 # insert function_proto into model_proto. 1640 proto = onnx_proto_utils._add_onnxscript_fn( 1641 proto, 1642 custom_opsets, 1643 ) 1644 if verbose: 1645 torch.onnx.log("Exported graph: ", graph) 1646 onnx_proto_utils._export_file(proto, f, export_type, export_map) 1647 finally: 1648 assert GLOBALS.in_onnx_export 1649 GLOBALS.in_onnx_export = False 1650 GLOBALS.autograd_inlining = _autograd_inlining_previous 1651 _reset_trace_module_map() 1652 1653 return torch_out 1654 1655 1656def _apply_friendly_debug_names(graph, params): 1657 for n in graph.nodes(): 1658 for v in n.inputs(): 1659 old_name = v.debugName() 1660 if old_name != str(v.unique()): 1661 continue 1662 new_name = f"{n.kind()}_{v.unique()}" 1663 v.setDebugName(new_name) 1664 if old_name in params: 1665 params[new_name] = params.pop(old_name) 1666 1667 1668def _set_input_and_output_names(graph, input_names, output_names): 1669 def set_names(node_list, name_list, descriptor): 1670 if name_list is None: 1671 return 1672 if len(name_list) > len(node_list): 1673 raise RuntimeError( 1674 "number of %s names provided (%d) exceeded number of %ss (%d)" 1675 % (descriptor, len(name_list), descriptor, len(node_list)) 1676 ) 1677 1678 # Mark if the output node DebugName is set before. 1679 output_node_set = set() 1680 for i, (name, node) in enumerate(zip(name_list, node_list)): 1681 # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName(). 1682 if descriptor == "output": 1683 if node in output_node_set: 1684 identity_node = graph.create("onnx::Identity") 1685 identity_node.insertAfter(node.node()) 1686 identity_node.addInput(node) 1687 identity_node.output().setType(node.type()) 1688 graph.return_node().replaceInput(i, identity_node.output()) 1689 node = identity_node.output() 1690 output_node_set.add(node) 1691 1692 if node.debugName() != name: 1693 node.setDebugName(name) 1694 1695 set_names(list(graph.inputs()), input_names, "input") 1696 set_names(list(graph.outputs()), output_names, "output") 1697 1698 1699def _run_symbolic_method(g, op_name, symbolic_fn, args): 1700 r""" 1701 This trampoline function gets invoked for every symbolic method 1702 call from C++. 1703 """ 1704 try: 1705 graph_context = jit_utils.GraphContext( 1706 graph=g, 1707 block=g.block(), 1708 opset=GLOBALS.export_onnx_opset_version, 1709 original_node=None, # type: ignore[arg-type] 1710 params_dict=_params_dict, 1711 env={}, 1712 values_in_env=set(), 1713 new_nodes=[], 1714 ) 1715 return symbolic_fn(graph_context, *args) 1716 except TypeError as e: 1717 # Handle the specific case where we didn't successfully dispatch 1718 # to symbolic_fn. Otherwise, the backtrace will have the clues 1719 # you need. 1720 e.args = (f"{e.args[0]} (occurred when translating {op_name})",) 1721 raise 1722 1723 1724def _add_block(node: _C.Node) -> _C.Block: 1725 return node.addBlock() 1726 1727 1728def _add_input_to_block(block: _C.Block): 1729 return block.addInputToBlock() # type: ignore[attr-defined] 1730 1731 1732def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: 1733 return block.registerOutput(value) 1734 1735 1736def _should_aten_fallback( 1737 name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes 1738): 1739 # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, 1740 # an aten::ATen operator is created regardless of symbolics existence 1741 1742 is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) 1743 is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN 1744 is_aten_fallback_export = ( 1745 operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK 1746 ) 1747 1748 if not name.startswith("aten::"): 1749 return False 1750 1751 if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): 1752 return True 1753 1754 return False 1755 1756 1757def _get_aten_op_overload_name(n: _C.Node) -> str: 1758 # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds 1759 schema = n.schema() 1760 if not schema.startswith("aten::"): 1761 return "" 1762 return _C.parse_schema(schema).overload_name 1763 1764 1765def _run_symbolic_function( 1766 graph: _C.Graph, 1767 block: _C.Block, 1768 node: _C.Node, 1769 inputs: Any, 1770 env: dict[_C.Value, _C.Value], 1771 values_in_env: set[_C.Value], 1772 new_nodes: list[_C.Node], 1773 operator_export_type=_C_onnx.OperatorExportTypes.ONNX, 1774) -> _C.Value | Sequence[_C.Value | None] | None: 1775 """Runs a symbolic function. 1776 1777 The function is used in C++ to export the node to ONNX. 1778 1779 Returns: 1780 A single or a tuple of Values. 1781 None when the node gets cloned as is into the new graph. 1782 """ 1783 1784 opset_version = GLOBALS.export_onnx_opset_version 1785 1786 # See Note [Export inplace] 1787 node_kind = node.kind() 1788 if node_kind.endswith("_"): 1789 # Treat relu_ -> relu; add_ -> add etc. 1790 ns_op_name = node_kind[:-1] 1791 else: 1792 ns_op_name = node_kind 1793 1794 namespace, op_name = jit_utils.parse_node_kind(ns_op_name) 1795 1796 graph_context = jit_utils.GraphContext( 1797 graph=graph, 1798 block=block, 1799 opset=opset_version, 1800 original_node=node, 1801 params_dict=_params_dict, 1802 env=env, 1803 values_in_env=values_in_env, 1804 new_nodes=new_nodes, 1805 ) 1806 1807 # Direct ATen export requested 1808 if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): 1809 attrs = { 1810 k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) 1811 for k in node.attributeNames() 1812 } 1813 outputs = node.outputsSize() 1814 attrs["outputs"] = outputs 1815 return graph_context.aten_op( 1816 op_name, 1817 *inputs, 1818 overload_name=_get_aten_op_overload_name(node), 1819 **attrs, 1820 ) 1821 1822 try: 1823 domain = namespace 1824 symbolic_function_name = f"{domain}::{op_name}" 1825 1826 symbolic_function_group = registration.registry.get_function_group( 1827 symbolic_function_name 1828 ) 1829 if symbolic_function_group is not None: 1830 symbolic_fn = symbolic_function_group.get(opset_version) 1831 if symbolic_fn is not None: 1832 # TODO Wrap almost identical attrs assignment or comment the difference. 1833 attrs = { 1834 k: symbolic_helper._node_get(node, k) for k in node.attributeNames() 1835 } 1836 return symbolic_fn(graph_context, *inputs, **attrs) 1837 1838 attrs = { 1839 k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) 1840 for k in node.attributeNames() 1841 } 1842 if namespace == "onnx": 1843 # Clone node to trigger ONNX shape inference 1844 return graph_context.op( 1845 op_name, *inputs, **attrs, outputs=node.outputsSize() 1846 ) # type: ignore[attr-defined] 1847 1848 raise errors.UnsupportedOperatorError( 1849 symbolic_function_name, 1850 opset_version, 1851 symbolic_function_group.get_min_supported() 1852 if symbolic_function_group 1853 else None, 1854 ) 1855 1856 except RuntimeError: 1857 if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: 1858 return None 1859 elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: 1860 # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` 1861 attrs = { 1862 k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) 1863 for k in node.attributeNames() 1864 } 1865 return graph_context.aten_op( 1866 op_name, 1867 *inputs, 1868 overload_name=_get_aten_op_overload_name(node), 1869 **attrs, 1870 ) 1871 raise 1872 except TypeError as e: 1873 # Handle the specific case where we didn't successfully dispatch. 1874 # Otherwise, the backtrace will have the clues you need. 1875 e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",) 1876 raise 1877 1878 1879def _verify_custom_op_name(symbolic_name: str): 1880 if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name): 1881 raise errors.OnnxExporterError( 1882 f"Failed to register operator {symbolic_name}. " 1883 "The symbolic name must match the format domain::name, " 1884 "and should start with a letter and contain only " 1885 "alphanumerical characters" 1886 ) 1887 1888 ns, _ = jit_utils.parse_node_kind(symbolic_name) 1889 if ns == "onnx": 1890 raise ValueError( 1891 f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified." 1892 ) 1893 1894 1895def register_custom_op_symbolic( 1896 symbolic_name: str, 1897 symbolic_fn: Callable, 1898 opset_version: int, 1899): 1900 """Registers a symbolic function for a custom operator. 1901 1902 When the user registers symbolic for custom/contrib ops, 1903 it is highly recommended to add shape inference for that operator via setType API, 1904 otherwise the exported graph may have incorrect shape inference in some extreme cases. 1905 An example of setType is `test_aten_embedding_2` in `test_operators.py`. 1906 1907 See "Custom Operators" in the module documentation for an example usage. 1908 1909 Args: 1910 symbolic_name (str): The name of the custom operator in "<domain>::<op>" 1911 format. 1912 symbolic_fn (Callable): A function that takes in the ONNX graph and 1913 the input arguments to the current operator, and returns new 1914 operator nodes to add to the graph. 1915 opset_version (int): The ONNX opset version in which to register. 1916 """ 1917 if symbolic_name.startswith("::"): 1918 symbolic_name = f"aten{symbolic_name}" 1919 1920 _verify_custom_op_name(symbolic_name) 1921 1922 registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn) 1923 1924 1925def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): 1926 """Unregisters ``symbolic_name``. 1927 1928 See "Custom Operators" in the module documentation for an example usage. 1929 1930 Args: 1931 symbolic_name (str): The name of the custom operator in "<domain>::<op>" 1932 format. 1933 opset_version (int): The ONNX opset version in which to unregister. 1934 """ 1935 if symbolic_name.startswith("::"): 1936 symbolic_name = f"aten{symbolic_name}" 1937 1938 _verify_custom_op_name(symbolic_name) 1939 1940 registration.registry.unregister(symbolic_name, opset_version) 1941 1942 1943def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): 1944 """Ensures dynamic axes argument is follows the expected format.""" 1945 if len(dynamic_axes) == 0: 1946 return 1947 1948 if hasattr(model, "graph"): 1949 # Extracting set of valid input/output names that shall be used for dynamic_axes 1950 if (input_names is None) or len(input_names) == 0: 1951 input_names = [x.debugName() for x in model.graph.inputs()] 1952 if (output_names is None) or len(output_names) == 0: 1953 output_names = [y.debugName() for y in model.graph.outputs()] 1954 1955 valid_names = set((input_names or []) + (output_names or [])) 1956 1957 # If dynamic axes are provided as a list rather than dictionary, they should 1958 # first get converted to a dictionary in expected format. If desired axes names 1959 # are not provided for dynamic axes, automatic names shall be generated for 1960 # provided dynamic axes of specified input/output 1961 for key, value in dynamic_axes.items(): 1962 if key not in valid_names: 1963 warnings.warn( 1964 f"Provided key {key} for dynamic axes is not a valid input/output name" 1965 ) 1966 if isinstance(value, list): 1967 warnings.warn( 1968 "No names were found for specified dynamic axes of provided input." 1969 f"Automatically generated names will be applied to each dynamic axes of input {key}" 1970 ) 1971 1972 value_dict = {} 1973 for i, x in enumerate(value): 1974 if not isinstance(x, int): 1975 raise ValueError( 1976 "The type of axis index is expected to be an integer" 1977 ) 1978 if x in value_dict: 1979 warnings.warn( 1980 f"Duplicate dynamic axis index {x} was provided for input {key}." 1981 ) 1982 else: 1983 value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1) 1984 dynamic_axes[key] = value_dict 1985 1986 1987def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature: 1988 return inspect.signature( 1989 model.forward if isinstance(model, torch.nn.Module) else model 1990 ) 1991