1"""Compatibility functions for the torch.onnx.export API.""" 2 3# mypy: allow-untyped-defs 4# mypy: disable-error-code=attr-defined 5from __future__ import annotations 6 7import inspect 8import logging 9from typing import Any, Mapping, Sequence, TYPE_CHECKING 10 11import torch 12from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir 13from torch.onnx._internal.exporter import _core, _onnx_program 14 15 16if TYPE_CHECKING: 17 import os 18 19logger = logging.getLogger(__name__) 20 21 22def _signature(model) -> inspect.Signature: 23 should_be_callable = getattr(model, "forward", model) 24 if callable(should_be_callable): 25 return inspect.signature(should_be_callable) 26 raise ValueError("model has no forward method and is not callable") 27 28 29def _from_dynamic_axes_to_dynamic_shapes( 30 model, 31 *, 32 dynamic_axes=None, 33 output_names: set[str], 34 input_names: Sequence[str] | None = None, 35) -> dict[str, Any] | None: 36 """ 37 38 dynamic_axes examples: 39 (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} 40 (2) dynamic_axes = {"x": [0], "y": [1]} 41 42 these will be converted to dynamic_shapes respectively: 43 (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} 44 (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names 45 46 """ 47 # https://github.com/pytorch/pytorch/pull/128371 48 # 1. The function does not need to provide dynamic_shapes to torch.export.export 49 if dynamic_axes is None: 50 return None 51 52 if input_names is None: 53 input_names = [] 54 55 sig = _signature(model) 56 if len(input_names) > len(sig.parameters): 57 raise ValueError( 58 f"Number of input names ({len(input_names)}) should not be greater than " 59 f"the number of model inputs ({len(sig.parameters)})" 60 ) 61 input_names_to_model_inputs = {} 62 for idx, param_name in enumerate(sig.parameters): 63 if idx < len(input_names): 64 input_names_to_model_inputs[input_names[idx]] = param_name 65 else: 66 input_names_to_model_inputs[param_name] = param_name 67 68 # NOTE: torch.export.export does not support input names assignment, 69 # so we need to map input names to model inputs to create dynamic_shapes 70 # for the exported program 71 dynamic_shapes_to_exported_program = {} 72 for input_name, axes in dynamic_axes.items(): 73 if input_name in output_names: 74 # User specified an output name as a dynamic axis, so we skip it 75 continue 76 # input_name can be either from input_names or from the model inputs 77 if input_name not in input_names_to_model_inputs: 78 raise ValueError( 79 f"dynamic axis: {input_name} is not found in the input names: {input_names}" 80 ) 81 model_input_name = input_names_to_model_inputs[input_name] 82 if isinstance(axes, dict): 83 dynamic_shapes_to_exported_program[model_input_name] = { 84 k: torch.export.Dim(v) for k, v in axes.items() 85 } 86 elif isinstance(axes, list): 87 dynamic_shapes_to_exported_program[model_input_name] = { 88 k: torch.export.Dim(f"{model_input_name}_dim_{k}") for k in axes 89 } 90 else: 91 raise TypeError( 92 f"dynamic_axes value must be either a dict or a list, but got {type(axes)}" 93 ) 94 # torch.export.export needs static dim to present in dynamic_shapes 95 # for all input tensors, so we need to add them with None 96 for input_name in sig.parameters: 97 if input_name not in dynamic_shapes_to_exported_program: 98 dynamic_shapes_to_exported_program[input_name] = None # type: ignore[assignment] 99 100 return dynamic_shapes_to_exported_program 101 102 103def _get_torch_export_args( 104 args: tuple[Any, ...], 105 kwargs: dict[str, Any] | None, 106) -> tuple[tuple[Any, ...], dict[str, Any] | None]: 107 """Obtain the arguments for torch.onnx.export from the model and the input arguments.""" 108 if not kwargs and args and isinstance(args[-1], dict): 109 kwargs = args[-1] 110 args = args[:-1] 111 return args, kwargs 112 113 114def export_compat( 115 model: torch.nn.Module 116 | torch.export.ExportedProgram 117 | torch.jit.ScriptModule 118 | torch.jit.ScriptFunction, 119 args: tuple[Any, ...], 120 f: str | os.PathLike | None = None, 121 *, 122 kwargs: dict[str, Any] | None = None, 123 export_params: bool = True, 124 verbose: bool | None = None, 125 input_names: Sequence[str] | None = None, 126 output_names: Sequence[str] | None = None, 127 opset_version: int | None = None, 128 dynamic_axes: Mapping[str, Mapping[int, str]] 129 | Mapping[str, Sequence[int]] 130 | None = None, 131 dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, 132 keep_initializers_as_inputs: bool = False, 133 external_data: bool = True, 134 report: bool = False, 135 verify: bool = False, 136 profile: bool = False, 137 dump_exported_program: bool = False, 138 artifacts_dir: str | os.PathLike = ".", 139 fallback: bool = False, 140 **_, 141) -> _onnx_program.ONNXProgram: 142 if opset_version is None: 143 # TODO(justinchuby): Change the hardcoded opset version for it to be flexible 144 opset_version = 18 145 146 if isinstance(model, torch.export.ExportedProgram): 147 # We know the model is already exported program, so the args, kwargs, and dynamic_shapes 148 # are not used 149 dynamic_shapes = dynamic_shapes or {} 150 else: 151 args, kwargs = _get_torch_export_args(args, kwargs) 152 if dynamic_shapes is None and dynamic_axes is not None: 153 dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes( 154 model, 155 dynamic_axes=dynamic_axes, 156 input_names=input_names, 157 output_names=set(output_names or ()), 158 ) 159 160 try: 161 onnx_program = _core.export( 162 model, 163 args, 164 kwargs, 165 registry=None, 166 dynamic_shapes=dynamic_shapes, 167 input_names=input_names, 168 output_names=output_names, 169 profile=profile, 170 report=report, 171 verify=verify, 172 dump_exported_program=dump_exported_program, 173 artifacts_dir=artifacts_dir, 174 verbose=verbose, 175 ) 176 177 except Exception as e: 178 if fallback: 179 if verbose is not False: 180 print( 181 "[torch.onnx] Falling back to legacy torch.onnx.export due " 182 f"to the following error: {e}", 183 ) 184 if f is None: 185 raise TypeError("f must be provided when fallback is enabled") from e 186 torch.onnx.utils.export( 187 model, # type: ignore[arg-type] 188 args, 189 f, # type: ignore[arg-type] 190 kwargs=kwargs, 191 export_params=export_params, 192 input_names=input_names, 193 output_names=output_names, 194 opset_version=17, # TODO(justinchuby): Hard coded to 17 for now 195 dynamic_axes=dynamic_axes, 196 keep_initializers_as_inputs=keep_initializers_as_inputs, 197 ) 198 onnx_program = _onnx_program.ONNXProgram(ir.load(f), None) 199 else: 200 raise 201 202 # Converter opset version and optimize 203 onnx_program.model = onnxscript_apis.convert_version( 204 onnx_program.model, opset_version 205 ) 206 onnx_program.model = onnxscript_apis.optimize(onnx_program.model) 207 208 if f is not None: 209 onnx_program.save( 210 f, 211 include_initializers=export_params, 212 keep_initializers_as_inputs=keep_initializers_as_inputs, 213 external_data=external_data, 214 ) 215 216 return onnx_program 217