# mypy: allow-untyped-defs
"""Functions to export models into the ONNX IR format.

These models can be loaded with the ONNX library and then
converted to models which run on other deep learning frameworks.
"""

from __future__ import annotations

import contextlib
import copy
import inspect
import re
import typing
import warnings
from typing import Any, Callable, cast, Collection, Mapping, Sequence

import torch
import torch._C._onnx as _C_onnx
import torch.jit._trace
import torch.serialization
from torch import _C
from torch.onnx import (  # noqa: F401
    _constants,
    _deprecation,
    _exporter_states,
    errors,
    symbolic_helper,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration


__all__ = [
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "disable_apex_o2_state_dict_hook",
    "setup_onnx_logging",
    "exporter_context",
    "export",
    "model_signature",
    "warn_on_static_input_change",
    "unpack_quantized_tensor",
    "export_to_pretty_string",
    "unconvertible_ops",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
]


def is_in_onnx_export() -> bool:
    """Returns whether it is in the middle of ONNX export."""
    return GLOBALS.in_onnx_export


# TODO(justinchuby): Remove the dependency on this global variable from constant_fold.cpp
# The type annotation is omitted because IValue cannot be imported from torch._C.
_params_dict = {}  # type: ignore[var-annotated]


@contextlib.contextmanager
def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
    r"""A context manager to temporarily set the training mode of ``model``
    to ``mode``, resetting it when we exit the with-block.

    Args:
        model: Same type and meaning as ``model`` arg to :func:`export`.
        mode: Same type and meaning as ``training`` arg to :func:`export`.
    """
    if not isinstance(mode, _C_onnx.TrainingMode):
        raise TypeError(
            f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'."
        )
    originally_training: bool = False

    if hasattr(model, "training"):
        originally_training = model.training

        # ONNX opset 12 has better support for training-amenable models, with updated
        # versions of the dropout and batch_norm operators
        if mode == _C_onnx.TrainingMode.TRAINING or (
            mode == _C_onnx.TrainingMode.PRESERVE and originally_training
        ):
            GLOBALS.export_training = True
            if GLOBALS.export_onnx_opset_version < 12:
                warnings.warn(
                    "You are exporting the model in training mode with onnx opset "
                    f"version {GLOBALS.export_onnx_opset_version}. "
                    "Opset versions lower than opset 12 will not be able to export "
                    "nodes such as Dropout and BatchNorm correctly."
                )
        else:
            GLOBALS.export_training = False

        GLOBALS.training_mode = mode
        if mode == _C_onnx.TrainingMode.TRAINING:
            model.train(True)
        elif mode == _C_onnx.TrainingMode.EVAL:
            model.train(False)
        # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing

    try:
        yield
    finally:
        if hasattr(model, "training") and mode != _C_onnx.TrainingMode.PRESERVE:
            model.train(originally_training)

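# A minimal usage sketch (the Dropout module here is a hypothetical example),
# showing how the context manager above temporarily flips the training mode
# and restores it on exit:
#
#     dropout = torch.nn.Dropout(p=0.5)
#     dropout.train()  # dropout.training is now True
#     with select_model_mode_for_export(dropout, _C_onnx.TrainingMode.EVAL):
#         assert not dropout.training  # forced into eval mode for export
#     assert dropout.training  # original mode restored on exit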

@contextlib.contextmanager
def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction):
    # Apex O2 hooks state_dict to return fp16 weights as fp32, and the
    # exporter cannot identify them as the same tensors. Since this hook is
    # only used by the optimizer, it is safe to remove it while exporting.
    if not isinstance(model, torch.jit.ScriptFunction):
        model_hooks = {}  # type: ignore[var-annotated]
        for module in model.modules():
            for key, hook in module._state_dict_hooks.items():
                if type(hook).__name__ == "O2StateDictHook":
                    if module not in model_hooks:
                        model_hooks[module] = {}
                    model_hooks[module][key] = hook
            if module in model_hooks:
                for key in model_hooks[module]:
                    module._state_dict_hooks.pop(key)
        try:
            yield
        finally:
            # Add the hooks back
            for module, m_map in model_hooks.items():
                for key, hook in m_map.items():
                    module._state_dict_hooks[key] = hook
    else:
        yield


@contextlib.contextmanager
def setup_onnx_logging(verbose: bool):
    is_originally_enabled = torch.onnx.is_onnx_log_enabled()
    if is_originally_enabled or verbose:
        torch.onnx.enable_log()
    try:
        yield
    finally:
        if not is_originally_enabled:
            torch.onnx.disable_log()


@contextlib.contextmanager
def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
    with select_model_mode_for_export(
        model, mode
    ) as mode_ctx, disable_apex_o2_state_dict_hook(
        model
    ) as apex_ctx, setup_onnx_logging(
        verbose
    ) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx:
        yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx)


def _get_torch_export_args(
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None,
) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
    """Obtain the positional arguments and keyword arguments for torch.onnx.export."""
    if not kwargs and args and isinstance(args[-1], dict):
        kwargs = args[-1]
        args = args[:-1]
    return args, kwargs

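# For illustration: the helper above implements the legacy calling convention
# where a trailing dict in ``args`` is treated as keyword arguments when no
# explicit ``kwargs`` are given (``x`` and ``y`` are hypothetical tensors):
#
#     args, kwargs = _get_torch_export_args((x, {"y": y}), None)
#     # -> args == (x,), kwargs == {"y": y}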

def export(
    model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction,
    args: tuple[Any, ...] | torch.Tensor,
    f: str,
    *,
    kwargs: dict[str, Any] | None = None,
    export_params: bool = True,
    verbose: bool = False,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
    opset_version: int | None = None,
    do_constant_folding: bool = True,
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
    keep_initializers_as_inputs: bool | None = None,
    custom_opsets: Mapping[str, int] | None = None,
    export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
    autograd_inlining: bool = True,
) -> None:
    r"""Exports a model into ONNX format.

    If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
    :class:`torch.jit.ScriptFunction`, this runs
    ``model`` once in order to convert it to a TorchScript graph to be exported
    (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support
    for dynamic control flow as :func:`torch.jit.trace`.

    Args:
        model: The model to be exported.
        args:

            args can be structured either as:

            1. ONLY A TUPLE OF ARGUMENTS::

                args = (x, y, z)

            The tuple should contain model inputs such that ``model(*args)`` is a valid
            invocation of the model. Any non-Tensor arguments will be hard-coded into the
            exported model; any Tensor arguments will become inputs of the exported model,
            in the order they occur in the tuple.

            2. A TENSOR::

                args = torch.Tensor([1])

            This is equivalent to a 1-ary tuple of that Tensor.

            3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS::

                args = (x, {"y": input_y, "z": input_z})

            All but the last element of the tuple will be passed as non-keyword arguments,
            and named arguments will be set from the last element. If a named argument is
            not present in the dictionary, it is assigned the default value, or None if a
            default value is not provided.

            .. warning::
                This behavior will be deprecated in a future release. Please use the
                kwargs argument instead.

            .. note::
                If a dictionary is the last element of the args tuple, it will be
                interpreted as containing named arguments. In order to pass a dict as the
                last non-keyword arg, provide an empty dict as the last element of the args
                tuple. For example, instead of::

                    torch.onnx.export(
                        model,
                        (
                            x,
                            # WRONG: will be interpreted as named arguments
                            {y: z},
                        ),
                        "test.onnx.pb",
                    )

                Write::

                    torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb")

        f: Path to the output ONNX model file. E.g. "model.onnx".
        kwargs: Named arguments to the model.
        export_params: If True, all parameters will
            be exported. Set this to False if you want to export an untrained model.
            In this case, the exported model will first take all of its parameters
            as arguments, with the ordering as specified by ``model.state_dict().values()``
        verbose: If True, prints a description of the
            model being exported to stdout and turns on ONNX exporter logging.
            In addition, the final ONNX graph will include the field ``doc_string``
            from the exported model which mentions the source code locations for ``model``.
        training:
            * ``TrainingMode.EVAL``: export the model in inference mode.
            * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
                False and in training mode if model.training is True.
            * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations
                which might interfere with training.
        input_names (list of str, default empty list): names to assign to the
            input nodes of the graph, in order.
        output_names (list of str, default empty list): names to assign to the
            output nodes of the graph, in order.
        operator_export_type (enum, default OperatorExportTypes.ONNX):

            .. warning::
                This option will be deprecated in a future release. Future exported
                graphs will always use the default opset domain.

            * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops
                (in the default opset domain).
            * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops
                to standard ONNX ops in the default opset domain. If unable to do so
                (e.g. because support has not been added to convert a particular torch op to ONNX),
                fall back to exporting the op into a custom opset domain without conversion. Applies
                to `custom ops <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
                as well as ATen ops. For the exported model to be usable, the runtime must support
                these non-standard ops.
            * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten")
                are exported as ATen ops (in opset domain "org.pytorch.aten").
                `ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library, so
                this instructs the runtime to use PyTorch's implementation of these ops.

                .. warning::

                    Models exported this way are probably runnable only by Caffe2.

                    This may be useful if the numeric differences in implementations of operators are
                    causing large differences in behavior between PyTorch and Caffe2 (which is more
                    common on untrained models).

            * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op
                (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so
                (e.g. because support has not been added to convert a particular torch op to ONNX),
                fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for
                context.
                For example::

                    graph(%0 : Float):
                    %3 : int = prim::Constant[value=0]()
                    # conversion unsupported
                    %4 : Float = aten::triu(%0, %3)
                    # conversion supported
                    %5 : Float = aten::mul(%4, %0)
                    return (%5)

                Assuming ``aten::triu`` is not supported in ONNX, this will be exported as::

                    graph(%0 : Float):
                    %1 : Long() = onnx::Constant[value={0}]()
                    # not converted
                    %2 : Float = aten::ATen[operator="triu"](%0, %1)
                    # converted
                    %3 : Float = onnx::Mul(%2, %0)
                    return (%3)

                .. warning::

                    Models exported this way are probably runnable only by Caffe2.

        opset_version (int, default 17): The version of the
            `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
            to target. Must be >= 7 and <= 17.
        do_constant_folding: Apply the constant-folding optimization.
            Constant-folding will replace some of the ops that have all constant inputs
            with pre-computed constant nodes.
        dynamic_axes:

            By default the exported model will have the shapes of all input and output tensors
            set to exactly match those given in ``args``. To specify axes of tensors as
            dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:

            * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
                ``output_names``.
            * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
                list, each element is an axis index.

            For example::

                class SumModule(torch.nn.Module):
                    def forward(self, x):
                        return torch.sum(x, dim=1)


                torch.onnx.export(
                    SumModule(),
                    (torch.ones(2, 2),),
                    "onnx.pb",
                    input_names=["x"],
                    output_names=["sum"],
                )

            Produces::

                input {
                  name: "x"
                  ...
                      shape {
                        dim {
                          dim_value: 2  # axis 0
                        }
                        dim {
                          dim_value: 2  # axis 1
                ...
                output {
                  name: "sum"
                  ...
                      shape {
                        dim {
                          dim_value: 2  # axis 0
                ...

            While::

                torch.onnx.export(
                    SumModule(),
                    (torch.ones(2, 2),),
                    "onnx.pb",
                    input_names=["x"],
                    output_names=["sum"],
                    dynamic_axes={
                        # dict value: manually named axes
                        "x": {0: "my_custom_axis_name"},
                        # list value: automatic names
                        "sum": [0],
                    },
                )

            Produces::

                input {
                  name: "x"
                  ...
                      shape {
                        dim {
                          dim_param: "my_custom_axis_name"  # axis 0
                        }
                        dim {
                          dim_value: 2  # axis 1
                ...
                output {
                  name: "sum"
                  ...
                      shape {
                        dim {
                          dim_param: "sum_dynamic_axes_1"  # axis 0
                ...

        keep_initializers_as_inputs: If True, all the
            initializers (typically corresponding to parameters) in the
            exported graph will also be added as inputs to the graph. If False,
            then initializers are not added as inputs to the graph, and only
            the non-parameter inputs are added as inputs.
            This may allow for better optimizations (e.g. constant folding) by
            backends/runtimes.

            If True, the ``deduplicate_initializers`` pass will not be executed. This means
            initializers with duplicated values will not be deduplicated and
            will be treated as distinct inputs to the graph. This allows different
            input initializers to be supplied at runtime following export.

            If ``opset_version < 9``, initializers MUST be part of graph
            inputs and this argument will be ignored and the behavior will be
            equivalent to setting this argument to True.

        custom_opsets (dict[str, int], default empty dict): A dict with schema:

            * KEY (str): opset domain name
            * VALUE (int): opset version

            If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
            the opset version is set to 1. Only custom opset domain names and versions should be
            indicated through this argument.

        export_modules_as_functions: Flag to enable
            exporting all ``nn.Module`` forward calls as local functions in ONNX, or a set to
            indicate the particular types of modules to export as local functions in ONNX.
            This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
            ``opset_version`` < 15 implies IR version < 8, which means no local function support.
            Module variables will be exported as function attributes. There are two categories of function
            attributes.

            1. Annotated attributes: class variables that have
            `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
            type annotations will be exported as attributes.
            Annotated attributes are not used inside the subgraph of ONNX local function because
            they are not created by PyTorch JIT tracing, but they may be used by consumers
            to determine whether or not to replace the function with a particular fused kernel.

            2. Inferred attributes: variables that are used by operators inside the module. Attribute names
            will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
            python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.

            * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
            * ``True``: export all ``nn.Module`` forward calls as local function nodes.
            * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
                only if the type of the ``nn.Module`` is found in the set.

        autograd_inlining: Flag used to control whether to inline autograd functions.
            Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.

    Raises:
        :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
        :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
            uses an operator that is not supported by the exporter.
        :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export.
            All errors are subclasses of :class:`errors.OnnxExporterError`.
    """
    if operator_export_type != _C_onnx.OperatorExportTypes.ONNX:
        warnings.warn(
            "Setting `operator_export_type` to something other than default is deprecated. "
            "The option will be removed in a future release.",
            category=FutureWarning,
        )
    if training == _C_onnx.TrainingMode.TRAINING:
        warnings.warn(
            "Setting `training` to something other than default is deprecated. "
            "The option will be removed in a future release. Please set the training mode "
            "before exporting the model.",
            category=FutureWarning,
        )

    args = (args,) if isinstance(args, torch.Tensor) else args
    if kwargs is not None:
        args = args + (kwargs,)

    _export(
        model,
        args,
        f,
        export_params,
        verbose,
        training,
        input_names,
        output_names,
        operator_export_type=operator_export_type,
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        dynamic_axes=dynamic_axes,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        custom_opsets=custom_opsets,
        export_modules_as_functions=export_modules_as_functions,
        autograd_inlining=autograd_inlining,
    )

    return None

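# A minimal end-to-end usage sketch (the module is hypothetical; assumes
# "model.onnx" is a writable path):
#
#     class AddOne(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     torch.onnx.export(
#         AddOne(),
#         (torch.randn(2, 3),),
#         "model.onnx",
#         input_names=["x"],
#         output_names=["y"],
#         dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}},
#     )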

def _is_constant_tensor_list(node):
    if node.kind() != "prim::Constant":
        return False
    output_type = node.output().type()
    if output_type.isSubtypeOf(_C.ListType.ofTensors()):
        return True
    if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())):
        return True
    return False


# ONNX can't handle constants that are lists of tensors, which can
# get generated in constant prop. So we split them back into prim::ListConstructs


def _split_tensor_list_constants(g, block):
    for node in block.nodes():
        for subblock in node.blocks():
            _split_tensor_list_constants(g, subblock)
        if _is_constant_tensor_list(node):
            inputs = []
            for val in node.output().toIValue():
                input = g.insertConstant(val)
                input.node().moveBefore(node)
                input.node().copyMetadata(node)
                inputs.append(input)

            lc = (
                g.create("prim::ListConstruct", inputs)
                .insertBefore(node)
                .output()
                .setType(_C.ListType.ofTensors())
            )
            lc.node().copyMetadata(node)
            node.output().replaceAllUsesWith(lc)


def _optimize_graph(
    graph: _C.Graph,
    operator_export_type: _C_onnx.OperatorExportTypes,
    _disable_torch_constant_prop: bool = False,
    fixed_batch_size: bool = False,
    params_dict=None,
    dynamic_axes=None,
    input_names=None,
    module=None,
):
    if params_dict is None:
        params_dict = {}

    # Inline everything
    _C._jit_pass_inline(graph)

    # Remove fork/wait nodes
    _C._jit_pass_inline_fork_wait(graph)
    _C._jit_pass_lint(graph)
    if GLOBALS.autograd_inlining:
        _C._jit_pass_onnx_autograd_function_process(graph)
    _C._jit_pass_lower_all_tuples(graph)

    # We now record some ops like ones/zeros into a trace where we previously
    # recorded constants. Use constant prop to maintain our current level of
    # ONNX support without implementing symbolics for all of them.
    if not _disable_torch_constant_prop:
        _C._jit_pass_constant_propagation(graph)

    _split_tensor_list_constants(graph, graph)
    # Run DCE to eliminate dead parts of the graph that might have been
    # left behind by things like symbolic_override.
    _C._jit_pass_dce(graph)
    _C._jit_pass_lint(graph)

    # CSE should improve perf when Autocast is used with disabled cache.
    # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092
    # Must run before _C._jit_pass_erase_number_types to prevent type substitution.
    if _C._jit_pass_cse(graph):
        _C._jit_pass_onnx_lint(graph)

    _C._jit_pass_canonicalize_graph_fuser_ops(graph)
    _C._jit_pass_lint(graph)
    _C._jit_pass_peephole(graph, True)
    _C._jit_pass_fuse_addmm(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_peephole(graph, True)
    _C._jit_pass_lower_all_tuples(graph)
    # In _jit_pass_onnx, symbolic functions are called for each node for conversion.
    # However, there are nodes that cannot be converted without additional context.
    # For example, the number of outputs from split (and whether it is static or dynamic) is unknown
    # until the point where it is unpacked by listUnpack node.
    # This pass does a preprocess, and prepares the nodes such that enough context can be received
    # by the symbolic function.
    _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
    _C._jit_pass_onnx_preprocess(graph)

    # ONNX does not support tuples, so try to remove them.
    _C._jit_pass_lint(graph)

    # ONNX only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    _C._jit_pass_prepare_division_for_onnx(graph)

    _C._jit_pass_onnx_remove_print(graph)
    _C._jit_pass_onnx_preprocess_caffe2(graph)

    symbolic_helper._quantized_ops.clear()
    # Unpack quantized weights for conv and linear ops and insert into graph.
    _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict)
    # ONNX only supports tensors, so we turn all number types into tensors.
    _C._jit_pass_erase_number_types(graph)
    if GLOBALS.onnx_shape_inference:
        input_names = [] if input_names is None else input_names
        dynamic_axes = {} if dynamic_axes is None else dynamic_axes
        _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    _C._jit_pass_onnx_lint(graph)

    graph = _C._jit_pass_onnx(graph, operator_export_type)
    _C._jit_pass_onnx_lint(graph)
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_scalar_type_analysis(
        graph, True, GLOBALS.export_onnx_opset_version
    )
    _C._jit_pass_lint(graph)

    _C._jit_pass_onnx_peephole(
        graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
    )
    _C._jit_pass_lint(graph)

    # graph is not a valid jit graph anymore because types have been replaced
    # (e.g. int with Tensor), so it now contains operators that don't actually
    # exist. We can't run normal dead code elimination because it'd fail trying
    # to look up if an operator has side effects, but we can run a dead code
    # elimination variant that doesn't need to look up if an op has side effects.
    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    _C._jit_pass_lint(graph)
    graph = _C._jit_pass_canonicalize(graph)
    _C._jit_pass_lint(graph)
    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )

    return graph


def warn_on_static_input_change(input_states):
    """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.

    We accept dictionaries and strings as ONNX inputs, but they should only be
    used for configuration. We detect here if these inputs are modified, and if
    so we warn the user that the changes won't take effect in the traced ONNX
    graph.
    """
    for input, traced_input in zip(input_states[0], input_states[1]):
        if isinstance(input, dict):
            if list(input.keys()) != list(traced_input.keys()):
                warning = (
                    "We detected that you are modifying a dictionary that is an input to your "
                    "model. "
                    "Note that dictionaries are allowed as inputs in ONNX but they should be "
                    "handled with care. "
                    "Usage of dictionaries is not recommended, except for configuration purposes. "
                    "Also note that the order and values of the keys must remain the same. "
                )
                warnings.warn(warning)
        elif isinstance(input, str):
            if input != traced_input:
                warning = (
                    "The model seems to have string inputs/outputs. "
                    "Note that strings will not appear as inputs/outputs of the ONNX graph. "
                )
                warnings.warn(warning)


def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
    """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
    return arg_value


def _decide_keep_init_as_input(
    keep_initializers_as_inputs: bool | None,
    operator_export_type: _C_onnx.OperatorExportTypes,
    opset_version: int,
):
    """Decides whether the initializers in the graph should be listed as ONNX graph inputs.

    This method encapsulates the logic to decide whether the initializers in the graph
    should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4).
    If keep_initializers_as_inputs is not specified (None), then we decide whether to keep
    initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type
    is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other
    export types keep initializers as input (val_keep_init_as_ip=True).
    If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8,
    in which case it must be ignored because for opset version <= 8, all initializers MUST be
    part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True.

    Special handling is needed for opset version 8 or lower, because irrespective
    of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3
    semantics, i.e. all initializers must be listed as ONNX graph input.
    """

    if opset_version < 9:
        if keep_initializers_as_inputs is False:
            warnings.warn(
                "Setting 'keep_initializers_as_inputs=False' for opset version "
                "8 or lower would lead to an invalid ONNX graph. Therefore, "
                "'keep_initializers_as_inputs=False' is ignored during export. "
                "Exported model will have initializers as graph inputs (compliant "
                "to ONNX IR v3)."
            )
        return True  # i.e. True == initializers are part of graph input (ONNX IR v3)
    val_keep_init_as_ip = (
        True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
    )
    if (
        keep_initializers_as_inputs is None
        and operator_export_type is _C_onnx.OperatorExportTypes.ONNX
    ):
        val_keep_init_as_ip = False
    return val_keep_init_as_ip

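# For illustration, the decision above reduces to:
#   opset_version < 9                      -> True (ONNX IR v3 requires it)
#   explicit True/False with opset >= 9    -> respected as given
#   None with OperatorExportTypes.ONNX     -> False (IR v4: initializers are
#                                             not listed as graph inputs)
#   None with any other export type        -> True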

def _decide_add_node_names(add_node_names, operator_export_type):
    return _resolve_args_by_export_type(
        "add_node_names", add_node_names, operator_export_type
    )


def _decide_constant_folding(do_constant_folding, operator_export_type, training):
    do_constant_folding = _resolve_args_by_export_type(
        "do_constant_folding", do_constant_folding, operator_export_type
    )
    if do_constant_folding and (
        training is not None and training is not _C_onnx.TrainingMode.EVAL
    ):
        warnings.warn(
            "It is recommended that constant folding be turned off ('do_constant_folding=False') "
            "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAINING' "
            "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some "
            "learnable model parameters may not translate correctly in the exported ONNX model "
            "because constant folding mutates model parameters. Please consider "
            "turning off constant folding or setting training=TrainingMode.EVAL."
        )
    return do_constant_folding


def _signature(model) -> inspect.Signature:
    should_be_callable = getattr(model, "forward", model)
    if callable(should_be_callable):
        return inspect.signature(should_be_callable)
    raise ValueError("model has no forward method and is not callable")


def _decide_input_format(model, args):
    try:
        sig = _signature(model)
    except ValueError as e:
        warnings.warn(f"{e}, skipping _decide_input_format")
        return args
    try:
        ordered_list_keys = list(sig.parameters.keys())
        if ordered_list_keys[0] == "self":
            ordered_list_keys = ordered_list_keys[1:]
        args_dict: dict = {}
        if isinstance(args, list):
            args_list = args
        elif isinstance(args, tuple):
            args_list = list(args)
        else:
            args_list = [args]
        if isinstance(args_list[-1], dict):
            args_dict = args_list[-1]
            args_list = args_list[:-1]
        n_nonkeyword = len(args_list)
        for optional_arg in ordered_list_keys[n_nonkeyword:]:
            if optional_arg in args_dict:
                args_list.append(args_dict[optional_arg])
            # Check if this arg has a default value
            else:
                param = sig.parameters[optional_arg]
                if param.default != param.empty:
                    args_list.append(param.default)
        args = args_list if isinstance(args, list) else tuple(args_list)
    # Cases of models with no input args
    except IndexError:
        warnings.warn("No input args, skipping _decide_input_format")
    except Exception as e:
        warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}")
    return args

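# For illustration (hypothetical ``forward(self, x, y=None, z=None)`` and
# tensors ``x_t``/``z_t``): given args = (x_t, {"z": z_t}), the helper above
# returns (x_t, None, z_t), pulling named args into positional order and
# filling unsupplied parameters with their declared defaults.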

def _from_dynamic_axes_to_dynamic_shapes(
    model,
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
    input_names: Sequence[str] | None = None,
) -> dict[str, Any] | None:
    """Converts ``dynamic_axes`` into the ``dynamic_shapes`` format used by torch.export.

    dynamic_axes examples:
    (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
    (2) dynamic_axes = {"x": [0], "y": [1]}

    these will be converted to dynamic_shapes respectively:
    (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
    (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}}  # auto-generated dim names
    """
    if dynamic_axes is None:
        return None

    if input_names is None:
        input_names_set = set()
    else:
        input_names_set = set(input_names)

    dynamic_shapes: dict[str, Any | None] = {}
    for input_name, axes in dynamic_axes.items():
        if input_name in input_names_set:
            raise ValueError(
                "Assigning new input names is not supported yet. Please use the model's "
                "forward signature to specify input names in dynamic_axes."
            )
        if isinstance(axes, dict):
            dynamic_shapes[input_name] = {
                k: torch.export.Dim(v) for k, v in axes.items()
            }
        elif isinstance(axes, list):
            dynamic_shapes[input_name] = {
                k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes
            }
        else:
            raise TypeError(
                f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
            )
    # torch.export.export needs static dims to be present in dynamic_shapes
    # for all input tensors, so we need to add them with None
    try:
        sig = _signature(model)
    except ValueError as e:
        warnings.warn(f"{e}, skipping auto filling None on static axes...")
        return dynamic_shapes
    for input_name in sig.parameters.keys():
        if input_name not in dynamic_shapes:
            dynamic_shapes[input_name] = None
    return dynamic_shapes


def _trace(func, args, operator_export_type, return_outs=False):
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args,)

    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
        func,
        args,
        strict=False,
        _force_outplace=False,
        _return_inputs_states=True,
    )
    warn_on_static_input_change(inputs_states)

    trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
    if return_outs:
        return trace_graph, torch_out
    return trace_graph


def _trace_and_get_graph_from_model(model, args):
    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model.  Fail fast!
    orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()

    # Disable Autocast cache because it replaces kernel's weight and bias
    # by (undesired) constants.
    # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
    prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    torch.set_autocast_cache_enabled(False)
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
        model,
        args,
        strict=False,
        _force_outplace=False,
        _return_inputs_states=True,
    )
    torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)

    warn_on_static_input_change(inputs_states)

    if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
        raise RuntimeError(
            "state_dict changed after running the tracer; "
            "something weird is happening in your model!"
        )

    return trace_graph, torch_out


def _get_param_count_list(method_graph, args_params):
    param_count_list = []
    for input_, arg_params_ in zip(method_graph.inputs(), args_params):
        if "PackedParams" in str(input_.type()):
            in_vars, _ = torch.jit._flatten(arg_params_)
            param_count_list.append(len(in_vars))
        else:
            param_count_list.append(arg_params_ is not None)

    return param_count_list


def _check_flatten_did_not_remove(original, jit_flattened):
    """torch.jit._flatten removes None. Check if it did so in this case."""

    def flatten(x):
        if isinstance(x, (list, tuple)):
            for inner in x:
                yield from flatten(inner)
        elif isinstance(x, dict):
            for inner in x.values():
                yield from flatten(inner)
        else:
            yield x

    flattened_with_none = list(flatten(original))
    num_none = len(flattened_with_none) - len(jit_flattened)
    assert num_none >= 0
    if num_none:
        raise ValueError(
            f"args contained {num_none} None values after flattening. "
            "When exporting a ScriptModule or ScriptFunction, no args may "
            "be None because that breaks type propagation."
        )

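# For illustration: torch.jit._flatten silently drops None values, which the
# check above turns into a hard error for script exports:
#
#     args = (torch.ones(1), None)
#     flat, _ = torch.jit._flatten(args)         # only the tensor survives
#     _check_flatten_did_not_remove(args, flat)  # raises ValueError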

def _create_jit_graph(
    model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any]
) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]:
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        _check_flatten_did_not_remove(args, flattened_args)
        torch_out = None

        if isinstance(model, torch.jit.ScriptModule):
            try:
                graph = model.forward.graph  # type: ignore[attr-defined]
            except AttributeError as e:
                raise RuntimeError("'forward' method must be a script method") from e
            _C._jit_pass_onnx_function_substitution(graph)
            freezed_module = _C._freeze_module(
                cast(_C.ScriptModule, model._c), preserveParameters=True
            )
            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
            method_graph = module._get_method("forward").graph
            args_params = tuple(args) + tuple(params)
            param_count_list = _get_param_count_list(method_graph, args_params)
            in_vars, _ = torch.jit._flatten(args_params)
            graph = _C._propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), param_count_list, False, False
            )
            return graph, params, torch_out, module

        # torch.jit.ScriptFunction
        params = []
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None

    graph, torch_out = _trace_and_get_graph_from_model(model, args)
    _C._jit_pass_onnx_lint(graph)
    state_dict = torch.jit._unique_state_dict(model)
    params = list(state_dict.values())
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            inp.setDebugName(param_names[i - user_input_num])
    _C._jit_pass_onnx_function_substitution(graph)
    return graph, params, torch_out, None


def _get_named_param_dict(graph, params):
    input_and_param_names = [val.debugName() for val in graph.inputs()]
    param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
    _params_dict = dict(zip(param_names, params))
    return _params_dict


def _get_example_outputs(model, args):
    input_args = copy.deepcopy(args)
    input_kwargs = {}
    if input_args and isinstance(input_args[-1], dict):
        input_kwargs = input_args[-1]
        input_args = input_args[:-1]

    example_outputs = model(*input_args, **input_kwargs)
    if isinstance(example_outputs, list):
        example_outputs = [example_outputs]
    elif not isinstance(example_outputs, tuple):
        example_outputs = (example_outputs,)

    return example_outputs


_qtype_vtype_map = {
    torch.quint8: torch.uint8,
    torch.qint8: torch.int8,
    torch.qint32: torch.int32,
    torch.quint4x2: torch.int8,
}


def unpack_quantized_tensor(value, cast_onnx_accepted=True):
    if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
        q_value_dequantize = value.dequantize()
        q_scale = (
            torch.tensor(value.q_scale(), dtype=torch.double)
            if cast_onnx_accepted
            else torch.tensor(value.q_scale(), dtype=torch.float32)
        )
        q_zero_point = (
            torch.tensor(value.q_zero_point(), dtype=torch.int64)
            if cast_onnx_accepted
            else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype])
        )
        q_value = q_value_dequantize / q_scale + q_zero_point
        q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype])
        return q_value, q_scale, q_zero_point
    else:
        return (value,)

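# A small usage sketch (assumes a build with quantized tensor support):
#
#     t = torch.quantize_per_tensor(
#         torch.tensor([1.0, 2.0]), scale=0.1, zero_point=10, dtype=torch.quint8
#     )
#     q_value, q_scale, q_zero_point = unpack_quantized_tensor(t)
#     # q_value holds the uint8 integer representation; q_scale is float64 and
#     # q_zero_point is int64, the types ONNX quantize/dequantize ops accept.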

def _pre_trace_quant_model(model, args):
    r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise, returns
    the original model unchanged.

    This is due to https://github.com/pytorch/pytorch/issues/75761.
    """
    if any(
        hasattr(m, "_packed_params") for m in getattr(model, "modules", list)()
    ) or any(getattr(arg, "is_quantized", False) for arg in args):
        return torch.jit.trace(model, args)
    return model


def _model_to_graph(
    model,
    args,
    verbose=False,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    do_constant_folding=True,
    _disable_torch_constant_prop=False,
    fixed_batch_size=False,
    training=_C_onnx.TrainingMode.EVAL,
    dynamic_axes=None,
) -> tuple[
    _C.Graph,
    dict[str, torch.Tensor],
    torch.Tensor
    | tuple[torch.Tensor, ...]
    | list[torch.Tensor]
    | dict[str, torch.Tensor]
    | Any
    | None,
]:
    """Converts model into an ONNX graph.

    Returns:
        graph: A TorchScript IR Graph with ONNX nodes.
        params_dict: Dict from input param name to param value.
        torch_out: The output tensors resulting from the trace of ``model``.
            If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`,
            this will be None, since we are not doing any tracing.
    """
    # TODO: can we simplify this to always return a tuple of Tensor or None?

    # Special case for common case of passing a single Tensor
    if isinstance(args, (torch.Tensor, int, float, bool)):
        args = (args,)

    model = _pre_trace_quant_model(model, args)
    graph, params, torch_out, module = _create_jit_graph(model, args)
    params_dict = _get_named_param_dict(graph, params)

    try:
        graph = _optimize_graph(
            graph,
            operator_export_type,
            _disable_torch_constant_prop=_disable_torch_constant_prop,
            fixed_batch_size=fixed_batch_size,
            params_dict=params_dict,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            module=module,
        )
    except Exception:
        torch.onnx.log("Torch IR graph at exception: ", graph)
        raise

    is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
    if is_script:
        example_outputs = _get_example_outputs(model, args)
        example_outputs_final = ()
        for example_output in example_outputs:
            example_outputs_final += unpack_quantized_tensor(example_output)
        out_vars, desc = torch.jit._flatten(example_outputs_final)
        _C._jit_pass_onnx_assign_output_shape(
            graph,
            out_vars,
            desc,
            GLOBALS.onnx_shape_inference,
            is_script,
            GLOBALS.export_onnx_opset_version,
        )

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    else:
        if not isinstance(torch_out, (list, tuple)):
            output_wrapped = [torch_out]
        else:
            output_wrapped = torch_out  # type: ignore[assignment]

        output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped))
        # assign_output_shape pass is not compatible with quantized outputs.
        # Quantized outputs are flattened to 3 values in ONNX, while packed as
        # single value in PyTorch.
        if not any(getattr(out, "is_quantized", False) for out in output_tensors):
            _C._jit_pass_onnx_assign_output_shape(
                graph,
                output_tensors,
                out_desc,
                GLOBALS.onnx_shape_inference,
                is_script,
                GLOBALS.export_onnx_opset_version,
            )

    _set_input_and_output_names(graph, input_names, output_names)
    params_dict = _get_named_param_dict(graph, params)

    if (
        do_constant_folding
        and GLOBALS.export_onnx_opset_version
        >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
    ):
        if training is None or training == _C_onnx.TrainingMode.EVAL:
            params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)

        params_dict = _C._jit_pass_onnx_constant_fold(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )
        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(
            graph, params_dict, GLOBALS.export_onnx_opset_version
        )

    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if GLOBALS.export_onnx_opset_version < 9:
        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
    _C._jit_decay_packed_param_input_types(graph)

    # If output names lack a proper name and are identified only by their unique
    # ID, give them a legible name for debugging purposes.
    _apply_friendly_debug_names(graph, params_dict)

    return graph, params_dict, torch_out


@torch._disable_dynamo
@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead")
def export_to_pretty_string(
    model,
    args,
    export_params=True,
    verbose=False,
    training=_C_onnx.TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
    export_type=None,
    google_printer=False,
    opset_version=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    add_node_names=True,
    do_constant_folding=True,
    dynamic_axes=None,
):
    """Similar to :func:`export`, but returns a text representation of the ONNX model.

    Only the args that differ are listed below. All other args are the same
    as :func:`export`.

    Args:
        add_node_names (bool, default True): Whether or not to set
            NodeProto.name. This makes no difference unless
            ``google_printer=True``.
        google_printer (bool, default False): If False, will return a custom,
            compact representation of the model. If True will return the
            protobuf's `Message::DebugString()`, which is more verbose.

    Returns:
        A UTF-8 str containing a human-readable representation of the ONNX model.
    """
    if opset_version is None:
        opset_version = _constants.ONNX_DEFAULT_OPSET
    if custom_opsets is None:
        custom_opsets = {}
    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with exporter_context(model, training, verbose):
        val_keep_init_as_ip = _decide_keep_init_as_input(
            keep_initializers_as_inputs, operator_export_type, opset_version
        )
        val_add_node_names = _decide_add_node_names(
            add_node_names, operator_export_type
        )
        val_do_constant_folding = _decide_constant_folding(
            do_constant_folding, operator_export_type, training
        )
        args = _decide_input_format(model, args)
        graph, params_dict, torch_out = _model_to_graph(
            model,
            args,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            val_do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return graph._pretty_print_onnx(  # type: ignore[attr-defined]
            params_dict,
            opset_version,
            False,
            operator_export_type,
            google_printer,
            val_keep_init_as_ip,
            custom_opsets,
            val_add_node_names,
        )


@_deprecation.deprecated("2.5", "the future", "avoid using this function")
def unconvertible_ops(
    model,
    args,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    opset_version: int | None = None,
) -> tuple[_C.Graph, list[str]]:
    """Returns an approximated list of all ops that are not yet supported by :mod:`torch.onnx`.

    The list is approximated because some ops may be removed during the conversion
    process and don't need to be converted. Some other ops may have partial support
    that will fail conversion with particular inputs. Please open a GitHub issue
    for op support requests.

    Args:
        model: Same as the `model` parameter in :func:`torch.onnx.export`.
        args: Same as the `args` parameter in :func:`torch.onnx.export`.
        training: Same as the `training` parameter in :func:`torch.onnx.export`.
        opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`.

    Returns:
        The JIT graph and a list of unconvertible ops in the format of "domain::op".
    """

    opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET
    GLOBALS.export_onnx_opset_version = opset_version

    try:
        with exporter_context(model, training, verbose=False):
            # Create a mostly clean JIT graph that contains the plain aten and
            # other ops we can check with the symbolic registry.
            # NOTE: We don't want to actually convert any ops to ONNX or run any
            # symbolic functions because there is a higher chance that a pass
            # fails or an unconvertible op messes up the graph during ONNX conversion.
            # This way we can always generate a list just by looking at the names
            # of the ops in the graph.
            args = _decide_input_format(model, args)
            model = _pre_trace_quant_model(model, args)
            graph, _, _, module = _create_jit_graph(model, args)
            _C._jit_pass_inline(graph)
            _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
            _C._jit_pass_erase_number_types(graph)
            _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    except Exception as e:
        raise errors.OnnxExporterError(
            "Failed to discover unconvertible ops because of errors during the JIT graph "
            "generation process."
        ) from e

    unsupported_ops = []
    for node in graph.nodes():
        domain_op = node.kind()
        if domain_op.startswith(("onnx::", "prim::")):
            # We consider onnx and prim ops as supported ops, even though some "prim"
            # ops are not implemented as symbolic functions, because they may be
            # eliminated in the conversion passes. Users may still see errors caused
            # by prim ops even though they don't show up in the list.
            continue
        if not registration.registry.is_registered_op(
            domain_op.rstrip("_"), opset_version
        ):
            # We consider all registered ops supported, even though some of them are
            # only partially supported, because there is not yet a good way to check
            # if an op is fully supported.
            # TODO(justinchuby): Create a way to check if an op is fully supported.
            unsupported_ops.append(domain_op)
    return graph, unsupported_ops

1353
1354def _setup_trace_module_map(
1355    model: torch.nn.Module | torch.jit.ScriptModule,
1356    export_modules_as_functions: bool | Collection[type[torch.nn.Module]],
1357) -> set[str]:
1358    def __register_attribute_hook():
1359        attr_name = "_onnx_attrs"
1360
1361        def _track_module_attributes_forward_pre_hook(module, input):
1362            setattr(module, attr_name, _get_module_attributes(module))
1363
1364        def _track_module_attributes_forward_hook(module, input, output):
1365            tracing_state = _C._get_tracing_state()
1366            if not tracing_state:
1367                return
1368
1369            graph = tracing_state.graph()
1370            onnx_attrs = {}
1371            if hasattr(module, attr_name):
1372                onnx_attrs = getattr(module, attr_name)
1373                delattr(module, attr_name)
1374
1375            _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
1376
1377        for m in model.modules():
1378            m.register_forward_hook(_track_module_attributes_forward_hook)
1379            m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
1380
1381    def _unqualified_variable_name(qualified_name: str) -> str:
1382        """
1383        Parse qualified variable name and return the unqualified version.
1384
        Pure numeric atoms (such as container indices) are not descriptive on their
        own, so this function scans from the right and keeps everything from the
        last non-numeric atom onward.
1387
1388        Example:
1389            >>> _unqualified_variable_name("__main__.Foo.bar")
1390            'bar'
1391            >>> _unqualified_variable_name("__main__.Foo.bar.0")
1392            'bar.0'
1393        """
1394        name_atoms = qualified_name.split(".")
1395        for i, atom in reversed(list(enumerate(name_atoms))):
1396            if not atom.isnumeric():
1397                return ".".join(name_atoms[i:])
1398        return qualified_name
1399
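    # Map each module to a full scope name that combines the module's type name
    # with its unqualified attribute path; the tracer uses these names to assign
    # ONNX scopes to nodes.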
1400    trace_module_map = {
1401        _m: torch._C._jit_onnx_create_full_scope_name(
1402            torch.typename(type(_m)), _unqualified_variable_name(_n)
1403        )
1404        for _n, _m in model.named_modules()
1405    }
1406    torch.jit._trace._trace_module_map = trace_module_map
1407    if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
1408        module_typenames = {torch.typename(type(module)) for module in trace_module_map}
1409    elif isinstance(export_modules_as_functions, set) and export_modules_as_functions:
1410
1411        def _find_typename(v):
1412            if isinstance(v, type):
1413                return torch.typename(v)
1414            else:
1415                raise RuntimeError(
                    "Only types of `nn.Module` may be "
                    "passed in the set for argument `export_modules_as_functions`. "
1418                    f"Got `{type(v).__name__}`."
1419                )
1420
1421        module_typenames = {_find_typename(v) for v in export_modules_as_functions}
1422    else:
1423        module_typenames = set()
1424
1425    if module_typenames:
1426        __register_attribute_hook()
1427
1428    return module_typenames
1429
1430
1431def _reset_trace_module_map():
1432    torch.jit._trace._trace_module_map = None
1433    _C._jit_pass_onnx_clear_scope_records()
1434
1435
1436def _get_module_attributes(module):
1437    annotations = typing.get_type_hints(type(module))
1438    base_m_annotations = typing.get_type_hints(torch.nn.Module)
    for k in base_m_annotations:
        annotations.pop(k, None)
1440    # Check whether module attributes can be accessed. Some classes
1441    # define attributes but don't provide access to them in their
1442    # constructor.
1443    #
1444    # For example, torch.nn.Embedding has the `freeze` variable and its
1445    # type specified in the class but the attribute is not created in the
1446    # constructor. In other words, there is no `self.freeze = <True | False>`
1447    # in the constructor.
1448    #
1449    # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120
1450    attrs = {}
1451    for k in annotations:
1452        try:
1453            attrs[k] = getattr(module, k)
1454        except AttributeError:
1455            torch.onnx.log(f"Skipping module attribute '{k}'")
1456            continue
1457    return attrs
1458
1459
1460def _export(
1461    model,
1462    args,
1463    f,
1464    export_params=True,
1465    verbose=False,
1466    training=_C_onnx.TrainingMode.EVAL,
1467    input_names=None,
1468    output_names=None,
1469    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
1470    export_type=None,
1471    opset_version=None,
1472    do_constant_folding=True,
1473    dynamic_axes=None,
1474    keep_initializers_as_inputs=None,
1475    fixed_batch_size=False,
1476    custom_opsets=None,
1477    add_node_names=True,
1478    onnx_shape_inference=True,
1479    export_modules_as_functions: Any = False,
1480    autograd_inlining=True,
1481):
1482    assert GLOBALS.in_onnx_export is False
1483
1484    if export_type is None:
1485        export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
1486
1487    if isinstance(model, torch.nn.DataParallel):
1488        raise ValueError(
            "torch.nn.DataParallel is not supported by the ONNX "
            "exporter. Please unwrap the model via its `.module` "
            "attribute. Try "
            "torch.onnx.export(model.module, ...)"
1493        )
1494
1495    GLOBALS.onnx_shape_inference = onnx_shape_inference
1496
1497    if opset_version is None:
1498        opset_version = _constants.ONNX_DEFAULT_OPSET
1499
    # torch.onnx.export does not support opset versions above ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
1501    if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET:
        # We do not want to fail because we should still allow users to create
        # custom symbolic functions for opsets above the supported maximum.
1504        warnings.warn(
            f"Exporting to ONNX opset version {opset_version} is not supported "
            f"by 'torch.onnx.export()'. "
1507            f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. "
1508            f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ",
1509            category=errors.OnnxExporterWarning,
1510        )
1511
1512    if export_modules_as_functions and opset_version < 15:
1513        raise ValueError(
            "`export_modules_as_functions` is not supported for `opset_version` < 15. "
            "This is because `opset_version` < 15 implies IR version < 8, which means "
            "no local function support."
1517        )
1518    if not operator_export_type:
1519        operator_export_type = _C_onnx.OperatorExportTypes.ONNX
1520
1521    # By default, training=TrainingMode.EVAL,
1522    # which is good because running a model in training mode could result in
1523    # internal buffers getting updated, dropout getting applied, etc.
    # If you really know what you're doing, you can set
    # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE
    # (the latter preserves whatever the original training mode was).
1527    GLOBALS.export_onnx_opset_version = opset_version
1528    GLOBALS.operator_export_type = operator_export_type
1529
1530    try:
1531        GLOBALS.in_onnx_export = True
1532        _autograd_inlining_previous = GLOBALS.autograd_inlining
1533        GLOBALS.autograd_inlining = autograd_inlining
1534
1535        module_typenames_to_export_as_functions: set[str] = set()
1536        if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
1537            module_typenames_to_export_as_functions = _setup_trace_module_map(
1538                model, export_modules_as_functions
1539            )
1540
1541        with exporter_context(model, training, verbose):
1542            val_keep_init_as_ip = _decide_keep_init_as_input(
1543                keep_initializers_as_inputs,
1544                operator_export_type,
1545                opset_version,
1546            )
1547            val_add_node_names = _decide_add_node_names(
1548                add_node_names, operator_export_type
1549            )
1550            val_do_constant_folding = _decide_constant_folding(
1551                do_constant_folding, operator_export_type, training
1552            )
1553            # Normally f can be a file-like object, but for large models, the external data format requires a
1554            # valid `model_file_location`. Code in export.cpp will enforce this.
1555            if isinstance(f, str):
1556                model_file_location = f
1557            else:
1558                model_file_location = ""
1559            args = _decide_input_format(model, args)
1560            if dynamic_axes is None:
1561                dynamic_axes = {}
1562            _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
1563
1564            graph, params_dict, torch_out = _model_to_graph(
1565                model,
1566                args,
1567                verbose,
1568                input_names,
1569                output_names,
1570                operator_export_type,
1571                val_do_constant_folding,
1572                fixed_batch_size=fixed_batch_size,
1573                training=training,
1574                dynamic_axes=dynamic_axes,
1575            )
1576
            # TODO: Don't allocate an in-memory string for the protobuf
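            # PROTOBUF_FILE stores weights inline in the serialized proto; the other
            # export types return them as separate entries in export_map, so defer
            # embedding them in the proto.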
1578            defer_weight_export = (
1579                export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
1580            )
1581            if custom_opsets is None:
1582                custom_opsets = {}
1583
1584            _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
1585            node_attr_to_name = {}  # type: ignore[var-annotated]
1586            if module_typenames_to_export_as_functions:
1587                # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
1588                node_attr_to_name = _C._jit_pass_onnx_function_extraction(
1589                    graph,
1590                    module_typenames_to_export_as_functions,
1591                    list(params_dict.keys()),
1592                )
1593
1594            if keep_initializers_as_inputs is not True:
1595                params_dict = _C._jit_pass_onnx_deduplicate_initializers(  # type: ignore[assignment]
1596                    graph,
1597                    params_dict,  # type: ignore[arg-type]
1598                    getattr(model, "training", False),  # type: ignore[arg-type]
1599                )
1600            _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
1601            if export_params:
1602                (
1603                    proto,
1604                    export_map,
1605                    val_use_external_data_format,
1606                    node_names,
1607                ) = graph._export_onnx(  # type: ignore[attr-defined]
1608                    params_dict,
1609                    opset_version,
1610                    dynamic_axes,
1611                    defer_weight_export,
1612                    operator_export_type,
1613                    not verbose,
1614                    val_keep_init_as_ip,
1615                    custom_opsets,
1616                    val_add_node_names,
1617                    model_file_location,
1618                    node_attr_to_name,
1619                )
1620            else:
1621                (
1622                    proto,
1623                    export_map,
1624                    val_use_external_data_format,
1625                    node_names,
1626                ) = graph._export_onnx(  # type: ignore[attr-defined]
1627                    {},
1628                    opset_version,
1629                    dynamic_axes,
1630                    False,
1631                    operator_export_type,
1632                    not verbose,
1633                    val_keep_init_as_ip,
1634                    custom_opsets,
1635                    val_add_node_names,
1636                    model_file_location,
1637                    node_attr_to_name,
1638                )
1639            # insert function_proto into model_proto.
1640            proto = onnx_proto_utils._add_onnxscript_fn(
1641                proto,
1642                custom_opsets,
1643            )
1644            if verbose:
1645                torch.onnx.log("Exported graph: ", graph)
1646            onnx_proto_utils._export_file(proto, f, export_type, export_map)
1647    finally:
1648        assert GLOBALS.in_onnx_export
1649        GLOBALS.in_onnx_export = False
1650        GLOBALS.autograd_inlining = _autograd_inlining_previous
1651        _reset_trace_module_map()
1652
1653    return torch_out
1654
1655
1656def _apply_friendly_debug_names(graph, params):
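    # Rename values whose debug name is still the autogenerated unique id (e.g.
    # "7") to "<consumer op kind>_<unique id>" (e.g. "aten::relu_7"), re-keying
    # `params` so initializer names stay aligned with the renamed values.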
1657    for n in graph.nodes():
1658        for v in n.inputs():
1659            old_name = v.debugName()
1660            if old_name != str(v.unique()):
1661                continue
1662            new_name = f"{n.kind()}_{v.unique()}"
1663            v.setDebugName(new_name)
1664            if old_name in params:
1665                params[new_name] = params.pop(old_name)
1666
1667
1668def _set_input_and_output_names(graph, input_names, output_names):
1669    def set_names(node_list, name_list, descriptor):
1670        if name_list is None:
1671            return
1672        if len(name_list) > len(node_list):
1673            raise RuntimeError(
                f"number of {descriptor} names provided ({len(name_list)}) "
                f"exceeded number of {descriptor}s ({len(node_list)})"
1676            )
1677
        # Track output values whose DebugName has already been set.
1679        output_node_set = set()
1680        for i, (name, node) in enumerate(zip(name_list, node_list)):
            # If the same value is returned as multiple outputs, insert an onnx::Identity
            # node so each output can be given its own DebugName via setDebugName().
1682            if descriptor == "output":
1683                if node in output_node_set:
1684                    identity_node = graph.create("onnx::Identity")
1685                    identity_node.insertAfter(node.node())
1686                    identity_node.addInput(node)
1687                    identity_node.output().setType(node.type())
1688                    graph.return_node().replaceInput(i, identity_node.output())
1689                    node = identity_node.output()
1690                output_node_set.add(node)
1691
1692            if node.debugName() != name:
1693                node.setDebugName(name)
1694
1695    set_names(list(graph.inputs()), input_names, "input")
1696    set_names(list(graph.outputs()), output_names, "output")
1697
1698
1699def _run_symbolic_method(g, op_name, symbolic_fn, args):
1700    r"""
1701    This trampoline function gets invoked for every symbolic method
1702    call from C++.
1703    """
1704    try:
1705        graph_context = jit_utils.GraphContext(
1706            graph=g,
1707            block=g.block(),
1708            opset=GLOBALS.export_onnx_opset_version,
1709            original_node=None,  # type: ignore[arg-type]
1710            params_dict=_params_dict,
1711            env={},
1712            values_in_env=set(),
1713            new_nodes=[],
1714        )
1715        return symbolic_fn(graph_context, *args)
1716    except TypeError as e:
1717        # Handle the specific case where we didn't successfully dispatch
1718        # to symbolic_fn.  Otherwise, the backtrace will have the clues
1719        # you need.
1720        e.args = (f"{e.args[0]} (occurred when translating {op_name})",)
1721        raise
1722
1723
1724def _add_block(node: _C.Node) -> _C.Block:
1725    return node.addBlock()
1726
1727
1728def _add_input_to_block(block: _C.Block):
1729    return block.addInputToBlock()  # type: ignore[attr-defined]
1730
1731
1732def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:
1733    return block.registerOutput(value)
1734
1735
1736def _should_aten_fallback(
1737    name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
1738):
    # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
    #   an aten::ATen operator is created regardless of whether a symbolic function exists.
1741
1742    is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
1743    is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
1744    is_aten_fallback_export = (
1745        operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
1746    )
1747
1748    if not name.startswith("aten::"):
1749        return False
1750
1751    if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op):
1752        return True
1753
1754    return False
1755
1756
1757def _get_aten_op_overload_name(n: _C.Node) -> str:
    # Returns the `overload_name` attribute of an ATen op's schema, or "" for non-ATen nodes.
1759    schema = n.schema()
1760    if not schema.startswith("aten::"):
1761        return ""
1762    return _C.parse_schema(schema).overload_name
1763
1764
1765def _run_symbolic_function(
1766    graph: _C.Graph,
1767    block: _C.Block,
1768    node: _C.Node,
1769    inputs: Any,
1770    env: dict[_C.Value, _C.Value],
1771    values_in_env: set[_C.Value],
1772    new_nodes: list[_C.Node],
1773    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
1774) -> _C.Value | Sequence[_C.Value | None] | None:
1775    """Runs a symbolic function.
1776
1777    The function is used in C++ to export the node to ONNX.
1778
1779    Returns:
1780        A single or a tuple of Values.
1781        None when the node gets cloned as is into the new graph.
1782    """
1783
1784    opset_version = GLOBALS.export_onnx_opset_version
1785
1786    # See Note [Export inplace]
1787    node_kind = node.kind()
1788    if node_kind.endswith("_"):
1789        # Treat relu_ -> relu; add_ -> add etc.
1790        ns_op_name = node_kind[:-1]
1791    else:
1792        ns_op_name = node_kind
1793
1794    namespace, op_name = jit_utils.parse_node_kind(ns_op_name)
1795
1796    graph_context = jit_utils.GraphContext(
1797        graph=graph,
1798        block=block,
1799        opset=opset_version,
1800        original_node=node,
1801        params_dict=_params_dict,
1802        env=env,
1803        values_in_env=values_in_env,
1804        new_nodes=new_nodes,
1805    )
1806
1807    # Direct ATen export requested
1808    if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
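        # Attribute names are suffixed with the first letter of their JIT kind
        # (e.g. "alpha_f" for a float attribute) -- the convention the
        # graph-context op helpers use to infer ONNX attribute types.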
1809        attrs = {
1810            k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
1811            for k in node.attributeNames()
1812        }
1813        outputs = node.outputsSize()
1814        attrs["outputs"] = outputs
1815        return graph_context.aten_op(
1816            op_name,
1817            *inputs,
1818            overload_name=_get_aten_op_overload_name(node),
1819            **attrs,
1820        )
1821
1822    try:
1823        domain = namespace
1824        symbolic_function_name = f"{domain}::{op_name}"
1825
1826        symbolic_function_group = registration.registry.get_function_group(
1827            symbolic_function_name
1828        )
1829        if symbolic_function_group is not None:
1830            symbolic_fn = symbolic_function_group.get(opset_version)
1831            if symbolic_fn is not None:
                # TODO: Deduplicate the nearly identical attrs assignments or document the difference.
1833                attrs = {
1834                    k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
1835                }
1836                return symbolic_fn(graph_context, *inputs, **attrs)
1837
1838        attrs = {
1839            k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
1840            for k in node.attributeNames()
1841        }
1842        if namespace == "onnx":
1843            # Clone node to trigger ONNX shape inference
1844            return graph_context.op(
1845                op_name, *inputs, **attrs, outputs=node.outputsSize()
1846            )  # type: ignore[attr-defined]
1847
1848        raise errors.UnsupportedOperatorError(
1849            symbolic_function_name,
1850            opset_version,
1851            symbolic_function_group.get_min_supported()
1852            if symbolic_function_group
1853            else None,
1854        )
1855
1856    except RuntimeError:
1857        if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
1858            return None
1859        elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
            # Emit an ATen op when operator_export_type==ONNX_ATEN_FALLBACK
1861            attrs = {
1862                k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
1863                for k in node.attributeNames()
1864            }
1865            return graph_context.aten_op(
1866                op_name,
1867                *inputs,
1868                overload_name=_get_aten_op_overload_name(node),
1869                **attrs,
1870            )
1871        raise
1872    except TypeError as e:
1873        # Handle the specific case where we didn't successfully dispatch.
1874        # Otherwise, the backtrace will have the clues you need.
1875        e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",)
1876        raise
1877
1878
1879def _verify_custom_op_name(symbolic_name: str):
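    # Accepts names like "mylib::my_op"; rejects names without a domain (e.g.
    # "my_op") and names in the reserved "onnx" domain.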
1880    if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
1881        raise errors.OnnxExporterError(
1882            f"Failed to register operator {symbolic_name}. "
1883            "The symbolic name must match the format domain::name, "
            "where the op name starts with a letter and both parts contain only "
            "alphanumeric characters, underscores, and hyphens."
1886        )
1887
1888    ns, _ = jit_utils.parse_node_kind(symbolic_name)
1889    if ns == "onnx":
1890        raise ValueError(
1891            f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
1892        )
1893
1894
1895def register_custom_op_symbolic(
1896    symbolic_name: str,
1897    symbolic_fn: Callable,
1898    opset_version: int,
1899):
1900    """Registers a symbolic function for a custom operator.
1901
    When registering a symbolic function for a custom/contrib op, it is highly
    recommended to add shape inference for that operator via the setType API;
    otherwise the exported graph may have incorrect shape inference in some
    extreme cases. An example of setType is `test_aten_embedding_2` in
    `test_operators.py`.
1906
1907    See "Custom Operators" in the module documentation for an example usage.
1908
1909    Args:
1910        symbolic_name (str): The name of the custom operator in "<domain>::<op>"
1911            format.
1912        symbolic_fn (Callable): A function that takes in the ONNX graph and
1913            the input arguments to the current operator, and returns new
1914            operator nodes to add to the graph.
1915        opset_version (int): The ONNX opset version in which to register.
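
    Example:
        A minimal sketch (``mylib::noop`` is a hypothetical custom op; its
        symbolic function simply forwards the input)::

            import torch

            def noop_symbolic(g, input):
                # Map the custom op onto the standard ONNX Identity op.
                return g.op("Identity", input)

            torch.onnx.register_custom_op_symbolic(
                "mylib::noop", noop_symbolic, opset_version=9
            )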
1916    """
1917    if symbolic_name.startswith("::"):
1918        symbolic_name = f"aten{symbolic_name}"
1919
1920    _verify_custom_op_name(symbolic_name)
1921
1922    registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn)
1923
1924
1925def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
1926    """Unregisters ``symbolic_name``.
1927
1928    See "Custom Operators" in the module documentation for an example usage.
1929
1930    Args:
1931        symbolic_name (str): The name of the custom operator in "<domain>::<op>"
1932            format.
1933        opset_version (int): The ONNX opset version in which to unregister.
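
    Example:
        Undoing the hypothetical registration shown for
        :func:`register_custom_op_symbolic`::

            torch.onnx.unregister_custom_op_symbolic("mylib::noop", opset_version=9)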
1934    """
1935    if symbolic_name.startswith("::"):
1936        symbolic_name = f"aten{symbolic_name}"
1937
1938    _verify_custom_op_name(symbolic_name)
1939
1940    registration.registry.unregister(symbolic_name, opset_version)
1941
1942
1943def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
    """Ensures the dynamic_axes argument follows the expected format."""
1945    if len(dynamic_axes) == 0:
1946        return
1947
1948    if hasattr(model, "graph"):
        # Extract the set of valid input/output names to check dynamic_axes against.
1950        if (input_names is None) or len(input_names) == 0:
1951            input_names = [x.debugName() for x in model.graph.inputs()]
1952        if (output_names is None) or len(output_names) == 0:
1953            output_names = [y.debugName() for y in model.graph.outputs()]
1954
1955    valid_names = set((input_names or []) + (output_names or []))
1956
    # If dynamic axes are provided as a list rather than a dictionary, convert them
    # to a dictionary in the expected format, generating automatic axis names for
    # each listed dynamic axis of the given input/output.
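    # e.g. dynamic_axes={"input": [0, 2]} becomes
    #   {"input": {0: "input_dynamic_axes_1", 2: "input_dynamic_axes_2"}}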
1961    for key, value in dynamic_axes.items():
1962        if key not in valid_names:
1963            warnings.warn(
1964                f"Provided key {key} for dynamic axes is not a valid input/output name"
1965            )
1966        if isinstance(value, list):
1967            warnings.warn(
                "No names were found for specified dynamic axes of provided input. "
                f"Automatically generated names will be applied to each dynamic axis of input {key}"
1970            )
1971
1972            value_dict = {}
1973            for i, x in enumerate(value):
1974                if not isinstance(x, int):
1975                    raise ValueError(
1976                        "The type of axis index is expected to be an integer"
1977                    )
1978                if x in value_dict:
1979                    warnings.warn(
1980                        f"Duplicate dynamic axis index {x} was provided for input {key}."
1981                    )
1982                else:
                    value_dict[x] = f"{key}_dynamic_axes_{i + 1}"
1984            dynamic_axes[key] = value_dict
1985
1986
1987def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature:
1988    return inspect.signature(
1989        model.forward if isinstance(model, torch.nn.Module) else model
1990    )
1991