# mypy: allow-untyped-defs
# Public entry point of the ``torch.onnx`` package: re-exports the exporter
# enums, helpers and opset modules, and defines ``export`` / ``dynamo_export``.
# NOTE: the order of the imports below matters (see the ``usort: skip``
# comments); do not let an import sorter reorder them.
from __future__ import annotations


# Names exported by ``from torch.onnx import *``. Note that
# ``is_onnx_log_enabled``, ``set_log_stream``, ``log``, ``producer_name`` and
# ``producer_version`` are defined in this module but are not listed here.
__all__ = [
    # Modules
    "symbolic_helper",
    "utils",
    "errors",
    # All opsets
    "symbolic_caffe2",
    "symbolic_opset7",
    "symbolic_opset8",
    "symbolic_opset9",
    "symbolic_opset10",
    "symbolic_opset11",
    "symbolic_opset12",
    "symbolic_opset13",
    "symbolic_opset14",
    "symbolic_opset15",
    "symbolic_opset16",
    "symbolic_opset17",
    "symbolic_opset18",
    "symbolic_opset19",
    "symbolic_opset20",
    # Enums
    "ExportTypes",
    "OperatorExportTypes",
    "TrainingMode",
    "TensorProtoDataType",
    "JitScalarType",
    # Public functions
    "export",
    "export_to_pretty_string",
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
    "disable_log",
    "enable_log",
    # Base error
    "OnnxExporterError",
    # Dynamo Exporter
    "DiagnosticOptions",
    "ExportOptions",
    "ONNXProgram",
    "ONNXRuntimeOptions",
    "OnnxRegistry",
    "dynamo_export",
    "enable_fake_mode",
    # DORT / torch.compile
    "is_onnxrt_backend_supported",
]

from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING

import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode

from ._exporter_states import ExportTypes
# The ORT backend classes are imported under private aliases; in this module
# they are only used to rewrite their ``__module__`` below.
from ._internal.onnxruntime import (
    is_onnxrt_backend_supported,
    OrtBackend as _OrtBackend,
    OrtBackendOptions as _OrtBackendOptions,
    OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import OnnxExporterError
# The underscore-prefixed names are not used in this module and are not in
# ``__all__``; presumably they are re-exported so callers can still reach them
# as ``torch.onnx._optimize_graph`` etc. — TODO confirm.
from .utils import (
    _optimize_graph,
    _run_symbolic_function,
    _run_symbolic_method,
    export_to_pretty_string,
    is_in_onnx_export,
    register_custom_op_symbolic,
    select_model_mode_for_export,
    unregister_custom_op_symbolic,
)


from . import (  # usort: skip. Keep the order instead of sorting lexicographically
    errors,
    symbolic_caffe2,
    symbolic_helper,
    symbolic_opset7,
    symbolic_opset8,
    symbolic_opset9,
    symbolic_opset10,
    symbolic_opset11,
    symbolic_opset12,
    symbolic_opset13,
    symbolic_opset14,
    symbolic_opset15,
    symbolic_opset16,
    symbolic_opset17,
    symbolic_opset18,
    symbolic_opset19,
    symbolic_opset20,
    utils,
)


from ._internal._exporter_legacy import (  # usort: skip. needs to be last to avoid circular import
    DiagnosticOptions,
    ExportOptions,
    ONNXProgram,
    ONNXRuntimeOptions,
    OnnxRegistry,
    enable_fake_mode,
)


# ``os`` is only needed for the ``os.PathLike`` annotations; with
# ``from __future__ import annotations`` those are never evaluated at runtime.
if TYPE_CHECKING:
    import os

# Set namespace for exposed private names
# These objects are defined in private submodules; rewriting ``__module__``
# makes them report ``torch.onnx`` as their home module (presumably for
# documentation and reprs — TODO confirm).
DiagnosticOptions.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
OnnxRegistry.__module__ = "torch.onnx"
_OrtBackend.__module__ = "torch.onnx"
_OrtBackendOptions.__module__ = "torch.onnx"
_OrtExecutionProvider.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
is_onnxrt_backend_supported.__module__ = "torch.onnx"

# Producer metadata; the version string comes from the C++ extension.
producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION


def export(
    model: torch.nn.Module
    | torch.export.ExportedProgram
    | torch.jit.ScriptModule
    | torch.jit.ScriptFunction,
    args: tuple[Any, ...] = (),
    f: str | os.PathLike | None = None,
    *,
    kwargs: dict[str, Any] | None = None,
    export_params: bool = True,
    verbose: bool | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    opset_version: int | None = None,
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
    keep_initializers_as_inputs: bool = False,
    dynamo: bool = False,
    # Dynamo only options
    external_data: bool = True,
    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
    report: bool = False,
    verify: bool = False,
    profile: bool = False,
    dump_exported_program: bool = False,
    artifacts_dir: str | os.PathLike = ".",
    fallback: bool = False,
    # Deprecated options
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
    do_constant_folding: bool = True,
    custom_opsets: Mapping[str, int] | None = None,
    export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
    autograd_inlining: bool = True,
    **_: Any,  # ignored options
) -> Any | None:
    r"""Exports a model into ONNX format.

    Two exporters are available. The ``torch.export``-based ("dynamo") exporter
    is used when ``dynamo=True`` or when ``model`` is already a
    :class:`torch.export.ExportedProgram`; otherwise the TorchScript exporter is used.

    Args:
        model: The model to be exported.
        args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
            exported model; any Tensor arguments will become inputs of the exported model,
            in the order they occur in the tuple. On the dynamo path, a single
            :class:`torch.Tensor` passed here is treated as a one-element tuple.
        f: Path to the output ONNX model file. E.g. "model.onnx".
        kwargs: Optional example keyword inputs.
        export_params: If False, parameters (weights) will not be exported.
        verbose: Whether to enable verbose logging. When the TorchScript exporter is
            used, ``None`` is treated as ``False``.
        input_names: Names to assign to the input nodes of the graph, in order.
        output_names: Names to assign to the output nodes of the graph, in order.
        opset_version: The version of the
            `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
            to target. Must be >= 7.
        dynamic_axes:

            By default the exported model will have the shapes of all input and output tensors
            set to exactly match those given in ``args``. To specify axes of tensors as
            dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:

            * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
              ``output_names``.
            * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
              list, each element is an axis index.

            For example::

                class SumModule(torch.nn.Module):
                    def forward(self, x):
                        return torch.sum(x, dim=1)


                torch.onnx.export(
                    SumModule(),
                    (torch.ones(2, 2),),
                    "onnx.pb",
                    input_names=["x"],
                    output_names=["sum"],
                )

            Produces::

                input {
                  name: "x"
                  ...
                      shape {
                        dim {
                          dim_value: 2  # axis 0
                        }
                        dim {
                          dim_value: 2  # axis 1
                ...
                output {
                  name: "sum"
                  ...
                      shape {
                        dim {
                          dim_value: 2  # axis 0
                ...

            While::

                torch.onnx.export(
                    SumModule(),
                    (torch.ones(2, 2),),
                    "onnx.pb",
                    input_names=["x"],
                    output_names=["sum"],
                    dynamic_axes={
                        # dict value: manually named axes
                        "x": {0: "my_custom_axis_name"},
                        # list value: automatic names
                        "sum": [0],
                    },
                )

            Produces::

                input {
                  name: "x"
                  ...
                      shape {
                        dim {
                          dim_param: "my_custom_axis_name"  # axis 0
                        }
                        dim {
                          dim_value: 2  # axis 1
                ...
                output {
                  name: "sum"
                  ...
                      shape {
                        dim {
                          dim_param: "sum_dynamic_axes_1"  # axis 0
                ...

        keep_initializers_as_inputs: If True, all the
            initializers (typically corresponding to model weights) in the
            exported graph will also be added as inputs to the graph. If False,
            then initializers are not added as inputs to the graph, and only
            the user inputs are added as inputs.

            Set this to True if you intend to supply model weights at runtime.
            Set it to False if the weights are static to allow for better optimizations
            (e.g. constant folding) by backends/runtimes.

        dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
            A ``model`` that is already an ExportedProgram is always exported this way,
            regardless of this flag. The options ``external_data`` through ``fallback``
            below only take effect on this path.
        external_data: Whether to save the model weights as an external data file.
            This is required for models with large weights that exceed the ONNX file size limit (2GB).
            When False, the weights are saved in the ONNX file with the model architecture.
        dynamic_shapes: A dictionary, tuple, or list of dynamic shapes for the model inputs. Refer to
            :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True.
            Only one parameter `dynamic_axes` or `dynamic_shapes` should be set
            at the same time. Passing a non-empty value when the TorchScript exporter
            is used raises ``ValueError``.
        report: Whether to generate a markdown report for the export process.
        verify: Whether to verify the exported model using ONNX Runtime.
        profile: Whether to profile the export process.
        dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
            This is useful for debugging the exporter.
        artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
            exported program.
        fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails.

        training: Deprecated option. Instead, set the training mode of the model before exporting.
            Like the other deprecated options below, it is only forwarded to the
            TorchScript exporter and is ignored on the dynamo path.
        operator_export_type: Deprecated option. Only ONNX is supported.
        do_constant_folding: Deprecated option. The exported graph is always optimized.
        custom_opsets: Deprecated.
            A dictionary:

            * KEY (str): opset domain name
            * VALUE (int): opset version

            If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
            the opset version is set to 1. Only custom opset domain name and version should be
            indicated through this argument.
        export_modules_as_functions: Deprecated option.

            Flag to enable
            exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
            particular types of modules to export as local functions in ONNX.
            This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
            ``opset_version`` < 15 implies IR version < 8, which means no local function support.
            Module variables will be exported as function attributes. There are two categories of function
            attributes.

            1. Annotated attributes: class variables that have type annotations via
            `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
            will be exported as attributes.
            Annotated attributes are not used inside the subgraph of ONNX local function because
            they are not created by PyTorch JIT tracing, but they may be used by consumers
            to determine whether or not to replace the function with a particular fused kernel.

            2. Inferred attributes: variables that are used by operators inside the module. Attribute names
            will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
            python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.

            * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
            * ``True``: export all ``nn.Module`` forward calls as local function nodes.
            * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
              only if the type of the ``nn.Module`` is found in the set.
        autograd_inlining: Deprecated.
            Flag used to control whether to inline autograd functions.
            Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.

        Any additional keyword arguments are accepted and ignored.

    Returns:
        When the dynamo-based exporter is used, the object that exporter returns;
        otherwise ``None``.

    Raises:
        ValueError: If ``dynamic_shapes`` is set while the TorchScript exporter is used.
    """
    # Dispatch: the torch.export-based exporter handles explicit dynamo=True
    # requests and any model that is already an ExportedProgram.
    if dynamo is True or isinstance(model, torch.export.ExportedProgram):
        # Imported lazily, presumably to keep ``import torch.onnx`` light and
        # avoid import cycles — TODO confirm.
        from torch.onnx._internal import exporter

        # Accept a bare tensor as shorthand for a one-element args tuple.
        if isinstance(args, torch.Tensor):
            args = (args,)
        # The deprecated TorchScript-only options (training, operator_export_type,
        # do_constant_folding, custom_opsets, export_modules_as_functions,
        # autograd_inlining) are not forwarded on this path.
        return exporter.export_compat(
            model,
            args,
            f,
            kwargs=kwargs,
            export_params=export_params,
            verbose=verbose,
            input_names=input_names,
            output_names=output_names,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            external_data=external_data,
            dynamic_shapes=dynamic_shapes,
            report=report,
            verify=verify,
            profile=profile,
            dump_exported_program=dump_exported_program,
            artifacts_dir=artifacts_dir,
            fallback=fallback,
        )
    else:
        # Legacy TorchScript exporter. NOTE: this local import shadows the
        # enclosing ``export`` function's name inside this branch only.
        from torch.onnx.utils import export

        # ``dynamic_shapes`` is a torch.export concept; the TorchScript path
        # only understands ``dynamic_axes``. Empty/None values are allowed.
        if dynamic_shapes:
            raise ValueError(
                "The exporter only supports dynamic shapes "
                "through parameter dynamic_axes when dynamo=False."
            )

        # The dynamo-only options (external_data, report, verify, profile,
        # dump_exported_program, artifacts_dir, fallback) are not forwarded here.
        export(
            model,
            args,
            f,  # type: ignore[arg-type]
            kwargs=kwargs,
            export_params=export_params,
            verbose=verbose is True,  # None (unspecified) is treated as False
            input_names=input_names,
            output_names=output_names,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            training=training,
            operator_export_type=operator_export_type,
            do_constant_folding=do_constant_folding,
            custom_opsets=custom_opsets,
            export_modules_as_functions=export_modules_as_functions,
            autograd_inlining=autograd_inlining,
        )
        return None


def dynamo_export(
    model: torch.nn.Module | Callable | torch.export.ExportedProgram,  # type: ignore[name-defined]
    /,
    *model_args,
    export_options: ExportOptions | None = None,
    **model_kwargs,
) -> ONNXProgram | Any:
    """Export a torch.nn.Module to an ONNX graph.

    Args:
        model: The PyTorch model to be exported to ONNX.
        model_args: Positional inputs to ``model``.
        model_kwargs: Keyword inputs to ``model``.
        export_options: Options to influence the export to ONNX.

    Returns:
        An in-memory representation of the exported ONNX model.

    **Example 1 - Simplest export**
    ::

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, bias=None):
                out = self.linear(x)
                out = out + bias
                return out


        model = MyModel()
        kwargs = {"bias": 3.0}
        args = (torch.randn(2, 2, 2),)
        onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs)
        onnx_program.save("my_simple_model.onnx")

    **Example 2 - Exporting with dynamic shapes**
    ::

        # The previous model can be exported with dynamic shapes
        export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
        onnx_program = torch.onnx.dynamo_export(
            model, *args, **kwargs, export_options=export_options
        )
        onnx_program.save("my_dynamic_model.onnx")
    """

    # NOTE: The new exporter is experimental and is not enabled by default.
    import warnings

    from torch.onnx import _flags
    from torch.onnx._internal import exporter
    from torch.utils import _pytree

    # Path 1: an already-captured ExportedProgram goes straight to
    # ``export_compat`` (opset 18, TorchScript fallback enabled).
    # ``export_options`` is not consulted on this path.
    if isinstance(model, torch.export.ExportedProgram):
        return exporter.export_compat(
            model,  # type: ignore[arg-type]
            model_args,
            f=None,
            kwargs=model_kwargs,
            opset_version=18,
            external_data=True,
            export_params=True,
            fallback=True,
        )
    # Path 2: when ``_flags.USE_EXPERIMENTAL_LOGIC`` is set, also route through
    # ``export_compat``; only ``export_options.dynamic_shapes`` is honored.
    elif _flags.USE_EXPERIMENTAL_LOGIC:
        if export_options is not None:
            warnings.warn(
                "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
                "For a more comprehensive set of export options, including advanced features, please consider using "
                "`torch.onnx.export(..., dynamo=True)`. ",
                category=FutureWarning,
            )

        if export_options is not None and export_options.dynamic_shapes:
            # Make all shapes dynamic
            def _to_dynamic_shapes_mapper():
                # Returns a fresh ``tree_map`` callback. ``arg_order`` counts
                # tensor leaves only, so each tensor gets its own Dim name
                # prefix ("arg_<k>_dim_<i>").
                arg_order = 0

                def _to_dynamic_shape(x):
                    # Tensor leaf -> {axis index: torch.export.Dim} covering every
                    # axis; any non-tensor leaf -> None (no dynamic dims).
                    nonlocal arg_order
                    if isinstance(x, torch.Tensor):
                        rank = len(x.shape)
                        dynamic_shape = {}
                        for i in range(rank):
                            dynamic_shape[i] = torch.export.Dim(
                                f"arg_{arg_order}_dim_{i}"
                            )
                        arg_order += 1
                        return dynamic_shape
                    else:
                        return None

                return _to_dynamic_shape

            # model_args could be nested
            # NOTE(review): only ``model_args`` are mapped; tensors passed via
            # ``model_kwargs`` receive no dynamic dimensions here.
            dynamic_shapes = _pytree.tree_map(
                _to_dynamic_shapes_mapper(),
                model_args,
            )
        else:
            dynamic_shapes = None

        return exporter.export_compat(
            model,  # type: ignore[arg-type]
            model_args,
            f=None,
            kwargs=model_kwargs,
            dynamic_shapes=dynamic_shapes,
            opset_version=18,
            external_data=True,
            export_params=True,
            fallback=True,
        )
    # Path 3 (default): delegate to the legacy dynamo exporter, which receives
    # ``export_options`` unchanged.
    else:
        from torch.onnx._internal._exporter_legacy import dynamo_export

        return dynamo_export(
            model, *model_args, export_options=export_options, **model_kwargs
        )


# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.

# Returns True iff ONNX logging is turned on.
is_onnx_log_enabled = _C._jit_is_onnx_log_enabled


def enable_log() -> None:
    r"""Enables ONNX logging."""
    _C._jit_set_onnx_log_enabled(True)


def disable_log() -> None:
    r"""Disables ONNX logging."""
    _C._jit_set_onnx_log_enabled(False)


# NOTE: the string literal below is a bare expression statement, not a
# docstring; it is not attached to ``set_log_stream`` and only documents it
# for readers of this file.
"""Sets output stream for ONNX logging.

Args:
    stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
        as ``stream_name``.
"""
set_log_stream = _C._jit_set_onnx_log_output_stream


# NOTE: as above, the string literal below documents ``log`` for readers only;
# it is not attached to the object as its docstring.
"""A simple logging facility for ONNX exporter.

Args:
    args: Arguments are converted to string, concatenated together with a newline
        character appended to the end, and flushed to output stream.
"""
log = _C._jit_onnx_log